From bbeaf8b621977d3307cb91b2eeac7fdf1fe95e75 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Wed, 6 Nov 2024 09:54:16 -0800 Subject: [PATCH] backend: Tool config major refactoring (#822) * replace with iD * wip * wip * Wip * more wip * Fix coral web unit tests * Resolve coral_web tests * Fix last case * Fix coral-web test * Refactor get_available_tools to use function directly * wip * wip * wip, community tools todo * Fix tests * Lint * wip * Fix all unit tests * Remove makefile change * Fix unit tests * Fixing tests wip * Fix lint * Remove pdbs --- docs/config_details/config_description.md | 2 - docs/custom_tool_guides/tool_guide.md | 50 ++- docs/how_to_guides.md | 2 +- src/backend/chat/custom/custom.py | 16 +- src/backend/chat/custom/tool_calls.py | 4 +- src/backend/cli/constants.py | 18 +- src/backend/cli/main.py | 5 +- .../config/configuration.template.yaml | 8 - src/backend/config/settings.py | 1 - src/backend/config/tools.py | 291 ++---------------- src/backend/main.py | 5 + src/backend/pytest.ini | 4 +- src/backend/routers/auth.py | 17 +- src/backend/routers/organization.py | 5 +- src/backend/routers/tool.py | 15 +- src/backend/schemas/tool.py | 17 +- src/backend/services/auth/strategies/base.py | 7 +- src/backend/services/chat.py | 8 +- src/backend/services/file.py | 6 +- src/backend/services/request_validators.py | 8 +- .../tests/integration/routers/test_agent.py | 14 +- .../tests/unit/chat/test_tool_calls.py | 146 ++++----- .../tests/unit/config/test_deployments.py | 15 +- src/backend/tests/unit/config/test_tools.py | 0 src/backend/tests/unit/configuration.yaml | 1 - src/backend/tests/unit/crud/test_agent.py | 16 +- .../unit/crud/test_agent_tool_metadata.py | 36 ++- src/backend/tests/unit/crud/test_tool_auth.py | 6 +- src/backend/tests/unit/factories/agent.py | 14 +- .../unit/factories/agent_tool_metadata.py | 16 +- src/backend/tests/unit/factories/tool_auth.py | 4 +- src/backend/tests/unit/routers/test_agent.py | 36 +-- src/backend/tests/unit/routers/test_chat.py | 59 +--- src/backend/tests/unit/routers/test_tool.py | 26 +- src/backend/tools/base.py | 42 ++- src/backend/tools/brave_search/tool.py | 26 +- src/backend/tools/calculator.py | 23 +- src/backend/tools/files.py | 58 +++- src/backend/tools/google_drive/tool.py | 27 +- src/backend/tools/google_search.py | 23 +- src/backend/tools/hybrid_search.py | 28 +- src/backend/tools/lang_chain.py | 28 +- src/backend/tools/python_interpreter.py | 31 +- src/backend/tools/slack/tool.py | 25 +- src/backend/tools/tavily_search.py | 52 +++- src/backend/tools/utils/mixins.py | 2 +- src/backend/tools/utils/tools_checkers.py | 20 +- src/backend/tools/web_scrape.py | 28 +- src/community/config/tools.py | 140 +-------- src/community/tools/__init__.py | 5 - src/community/tools/arxiv.py | 25 +- src/community/tools/clinicaltrials.py | 40 ++- src/community/tools/connector.py | 30 +- src/community/tools/llama_index.py | 35 ++- src/community/tools/pub_med.py | 25 +- src/community/tools/wolfram.py | 18 +- .../src/app/(main)/(chat)/Chat.tsx | 6 +- .../cohere-client/generated/schemas.gen.ts | 228 +++++++------- .../cohere-client/generated/services.gen.ts | 11 +- .../src/cohere-client/generated/types.gen.ts | 56 ++-- .../AgentSettingsForm/ToolsStep.tsx | 4 +- .../src/components/Composer/Composer.tsx | 6 +- .../components/Composer/ComposerToolbar.tsx | 4 +- .../components/Composer/DataSourceMenu.tsx | 6 +- .../components/Conversation/Conversation.tsx | 4 +- .../MessagingContainer/AssistantTools.tsx | 6 +- .../assistants_web/src/hooks/use-tools.ts | 12 +- 67 files changed, 995 insertions(+), 957 deletions(-) create mode 100644 src/backend/tests/unit/config/test_tools.py diff --git a/docs/config_details/config_description.md b/docs/config_details/config_description.md index ca80b1ba04..785eb9c936 100644 --- a/docs/config_details/config_description.md +++ b/docs/config_details/config_description.md @@ -23,8 +23,6 @@ - redis - Redis configurations - url - URL of the redis, for example, redis://:redis@redis:6379 - tools - Tool configurations - - enabled_tools - these are the tools that are enabled for the toolkit. The full list of tools can be found in the src/backend/config/tools.py file. - The community tools are listed in the src/community/config/tools.py file. Please note that the tools availability is checked too. - python_interpreter - Python interpreter configurations - url - URL of the python interpreter tool - feature_flags - Feature flags configurations diff --git a/docs/custom_tool_guides/tool_guide.md b/docs/custom_tool_guides/tool_guide.md index d8c37d57b9..e1eff9e9ed 100644 --- a/docs/custom_tool_guides/tool_guide.md +++ b/docs/custom_tool_guides/tool_guide.md @@ -53,7 +53,7 @@ from community.tools import BaseTool class ArxivRetriever(BaseTool): - NAME = "arxiv" + ID = "arxiv" def __init__(self): self.client = ArxivAPIWrapper() @@ -64,6 +64,27 @@ class ArxivRetriever(BaseTool): def is_available(cls) -> bool: return True + @classmethod + # You will need to add a tool definition here + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Arxiv", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Retrieves documents from Arxiv.", + ) + # Your tool needs to implement this call() method def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]: result = self.client.run(parameters) @@ -84,34 +105,9 @@ return [{"text": "The fox is blue", "url": "wikipedia.org/foxes", "title": "Colo Next, add your tool class to the init file by locating it in `src/community/tools/__init__.py`. Import your tool here, then add it to the `__all__` list. -To enable your tool, you will need to go to the `configuration.yaml` file and add your tool's name to the list of `enabled_tools`. This tool name will correspond to the one defined in the `NAME` attribute of your class. - Finally, you will need to add your tool definition to the config file. Locate it in `src/community/config/tools.py`, and import your tool at the top with `from backend.tools import ..`. -In the ToolName enum, add your tool as an enum value. For example, `My_Tool = MyTool.NAME`. - -In the `ALL_TOOLS` dictionary, add your tool definition. This should look like: - -```python - ToolName.My_Tool: ManagedTool( # THE TOOLNAME HERE CORRESPONDS TO THE ENUM YOU DEFINED EARLIER - display_name="My Tool", - implementation=MyTool, # THIS IS THE CLASS YOU IMPORTED AT THE TOP - parameter_definitions={ # THESE ARE PARAMS THE MODEL WILL SEND TO YOUR TOOL, ADJUST AS NEEDED - "query": { - "description": "Query to search with", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=MyTool.is_available(), - auth_implementation=None, # EMPTY IF NO AUTH NEEDED - error_message="Something went wrong", - category=Category.DataLoader, # CHECK CATEGORY ENUM FOR POSSIBLE VALUES - description="An example definition to get you started.", - ), -``` - +Finally, to enable your tool, add your tool as an enum value. For example, `My_Tool = MyToolClass`. ## Step 5: Test Your Tool! diff --git a/docs/how_to_guides.md b/docs/how_to_guides.md index b498f8cb66..5aee501b91 100644 --- a/docs/how_to_guides.md +++ b/docs/how_to_guides.md @@ -48,7 +48,7 @@ The core chat interface is the Coral frontend. To implement your own interface: If you have already created a [connector](https://docs.cohere.com/docs/connectors), you can utilize it within the toolkit by following these steps: 1. Configure your connector using `ConnectorRetriever`. -2. Add its definition in [community/config/tools.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py), following the `Arxiv` implementation, using the category `Category.DataLoader`. +2. Add its definition in [community/config/tools.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py), following the `Arxiv` implementation, using the category `ToolCategory.DataLoader`. You can now use both the Coral frontend and API with your connector. diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 16ce04e3b5..1fb5e78240 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -6,20 +6,19 @@ from backend.chat.custom.tool_calls import async_call_tools from backend.chat.custom.utils import get_deployment from backend.chat.enums import StreamEvent -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.database_models.file import File from backend.model_deployments.base import BaseDeployment from backend.schemas.chat import ChatMessage, ChatRole, EventState from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context -from backend.schemas.tool import Category, Tool +from backend.schemas.tool import Tool, ToolCategory from backend.services.chat import check_death_loop from backend.services.file import get_file_service from backend.tools.utils.tools_checkers import tool_has_category MAX_STEPS = 15 - class CustomChat(BaseChat): """Custom chat flow not using integrations for models.""" @@ -163,7 +162,7 @@ async def call_chat( file_reader_tools_names = [] if managed_tools: chat_request.tools = managed_tools - file_reader_tools_names = [tool.name for tool in managed_tools_full_schema if tool_has_category(tool, Category.FileLoader)] + file_reader_tools_names = [tool.name for tool in managed_tools_full_schema if tool_has_category(tool, ToolCategory.FileLoader)] # Get files if available all_files = [] @@ -248,17 +247,18 @@ def update_chat_history_with_tool_results( chat_request.chat_history.extend(tool_results) def get_managed_tools(self, chat_request: CohereChatRequest, full_schema=False): + available_tools = get_available_tools() if full_schema: return [ - AVAILABLE_TOOLS.get(tool.name) + available_tools.get(tool.name) for tool in chat_request.tools - if AVAILABLE_TOOLS.get(tool.name) + if available_tools.get(tool.name) ] return [ - Tool(**AVAILABLE_TOOLS.get(tool.name).model_dump()) + Tool(**available_tools.get(tool.name).model_dump()) for tool in chat_request.tools - if AVAILABLE_TOOLS.get(tool.name) + if available_tools.get(tool.name) ] def add_files_to_chat_history( diff --git a/src/backend/chat/custom/tool_calls.py b/src/backend/chat/custom/tool_calls.py index c5fab560bf..2c003878d5 100644 --- a/src/backend/chat/custom/tool_calls.py +++ b/src/backend/chat/custom/tool_calls.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from backend.chat.collate import rerank_and_chunk, to_dict -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.model_deployments.base import BaseDeployment from backend.schemas.context import Context from backend.services.logger.utils import LoggerFactory @@ -76,7 +76,7 @@ async def _call_tool_async( tool_call: dict, deployment_model: BaseDeployment, ) -> List[Dict[str, Any]]: - tool = AVAILABLE_TOOLS.get(tool_call["name"]) + tool = get_available_tools().get(tool_call["name"]) if not tool: logger.info( event=f"[Custom Chat] Tool not included in tools parameter: {tool_call['name']}", diff --git a/src/backend/cli/constants.py b/src/backend/cli/constants.py index e3f7dbd642..5d3facd9bd 100644 --- a/src/backend/cli/constants.py +++ b/src/backend/cli/constants.py @@ -22,9 +22,11 @@ class BuildTarget(StrEnum): PROD = "prod" -class ToolName(StrEnum): +class Tool(StrEnum): PythonInterpreter = "Python Interpreter" TavilyInternetSearch = "Tavily Internet Search" + Wolfram_Alpha = "Wolfram Alpha" + WELCOME_MESSAGE = r""" @@ -50,18 +52,28 @@ class ToolName(StrEnum): TOOLS = { - ToolName.PythonInterpreter: { + Tool.PythonInterpreter: { "secrets": { "PYTHON_INTERPRETER_URL": PYTHON_INTERPRETER_URL_DEFAULT, }, }, - ToolName.TavilyInternetSearch: { + Tool.TavilyInternetSearch: { "secrets": { "TAVILY_API_KEY": None, }, }, } +# For main.py cli setup script +COMMUNITY_TOOLS = { + Tool.Wolfram_Alpha: { + "secrets": { + "WOLFRAM_APP_ID": None, # default value + }, + }, +} + + ENV_YAML_CONFIG_MAPPING = { "USE_COMMUNITY_FEATURES": { "type": "config", diff --git a/src/backend/cli/main.py b/src/backend/cli/main.py index 9dbe44a62a..40d6eb7e79 100755 --- a/src/backend/cli/main.py +++ b/src/backend/cli/main.py @@ -1,6 +1,6 @@ import argparse -from backend.cli.constants import TOOLS +from backend.cli.constants import COMMUNITY_TOOLS, TOOLS from backend.cli.prompts import ( PROMPTS, community_tools_prompt, @@ -23,7 +23,6 @@ from community.config.deployments import ( AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP, ) -from community.config.tools import COMMUNITY_TOOLS_SETUP def start(): @@ -43,7 +42,7 @@ def start(): # ENABLE COMMUNITY TOOLS use_community_features = args.use_community and community_tools_prompt(secrets) if use_community_features: - TOOLS.update(COMMUNITY_TOOLS_SETUP) + TOOLS.update(COMMUNITY_TOOLS) # SET UP TOOLS for name, configs in TOOLS.items(): diff --git a/src/backend/config/configuration.template.yaml b/src/backend/config/configuration.template.yaml index b5ffb72e17..a2034de59a 100644 --- a/src/backend/config/configuration.template.yaml +++ b/src/backend/config/configuration.template.yaml @@ -20,14 +20,6 @@ database: redis: url: redis://:redis@redis:6379 tools: - enabled_tools: - - wikipedia - - search_file - - read_file - - toolkit_python_interpreter - - toolkit_calculator - - hybrid_web_search - - web_scrape hybrid_web_search: # List of web search tool names, eg: google_web_search, tavily_web_search enabled_web_searches: diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index a60be0942b..63e1853b7c 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -195,7 +195,6 @@ class HybridWebSearchSettings(BaseSettings, BaseModel): class ToolSettings(BaseSettings, BaseModel): model_config = SETTINGS_CONFIG - enabled_tools: Optional[List[str]] = None python_interpreter: Optional[PythonToolSettings] = Field( default=PythonToolSettings() diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 6a5d7a13b4..082978adaf 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -1,20 +1,18 @@ -from enum import StrEnum +from enum import Enum from backend.config.settings import Settings -from backend.schemas.tool import Category, ManagedTool +from backend.schemas.tool import ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools import ( BraveWebSearch, Calculator, GoogleDrive, - GoogleDriveAuth, GoogleWebSearch, HybridWebSearch, LangChainWikiRetriever, PythonInterpreter, ReadFileTool, SearchFileTool, - SlackAuth, SlackTool, TavilyWebSearch, WebScrapeTool, @@ -23,270 +21,41 @@ logger = LoggerFactory().get_logger() """ -List of available tools. Each tool should have a name, implementation, is_visible and category. -They can also have kwargs if necessary. - -You can switch the visibility of a tool by changing the is_visible parameter to True or False. -If a tool is not visible, it will not be shown in the frontend. - -If you want to add a new tool, check the instructions on how to implement a retriever in the documentation. -Don't forget to add the implementation to this AVAILABLE_TOOLS dictionary! +Tool Name enum, mapping to the tool's main implementation class. """ - -class ToolName(StrEnum): - Wiki_Retriever_LangChain = LangChainWikiRetriever.NAME - Search_File = SearchFileTool.NAME - Read_File = ReadFileTool.NAME - Python_Interpreter = PythonInterpreter.NAME - Calculator = Calculator.NAME - Google_Drive = GoogleDrive.NAME - Web_Scrape = WebScrapeTool.NAME - Tavily_Web_Search = TavilyWebSearch.NAME - Google_Web_Search = GoogleWebSearch.NAME - Brave_Web_Search = BraveWebSearch.NAME - Hybrid_Web_Search = HybridWebSearch.NAME - Slack = SlackTool.NAME - - -ALL_TOOLS = { - ToolName.Search_File: ManagedTool( - display_name="Search File", - implementation=SearchFileTool, - parameter_definitions={ - "search_query": { - "description": "Textual search query to search over the file's content for", - "type": "str", - "required": True, - }, - "files": { - "description": "A list of files represented as tuples of (filename, file ID) to search over", - "type": "list[tuple[str, str]]", - "required": True, - }, - }, - is_visible=True, - is_available=SearchFileTool.is_available(), - error_message="SearchFileTool not available.", - category=Category.FileLoader, - description="Performs a search over a list of one or more of the attached files for a textual search query", - ), - ToolName.Read_File: ManagedTool( - display_name="Read Document", - implementation=ReadFileTool, - parameter_definitions={ - "file": { - "description": "A file represented as a tuple (filename, file ID) to read over", - "type": "tuple[str, str]", - "required": True, - } - }, - is_visible=True, - is_available=ReadFileTool.is_available(), - error_message="ReadFileTool not available.", - category=Category.FileLoader, - description="Returns the textual contents of an uploaded file, broken up in text chunks.", - ), - ToolName.Python_Interpreter: ManagedTool( - display_name="Python Interpreter", - implementation=PythonInterpreter, - parameter_definitions={ - "code": { - "description": ( - "Python code to execute using the Python interpreter with no internet access. " - "Do not generate code that tries to open files directly, instead use file contents passed to the interpreter, " - "then print output or save output to a file." - ), - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=PythonInterpreter.is_available(), - error_message="PythonInterpreterFunctionTool not available, please make sure to set the tools.python_interpreter.url variable in your configuration.yaml", - category=Category.Function, - description="Runs python code in a sandbox.", - ), - ToolName.Wiki_Retriever_LangChain: ManagedTool( - display_name="Wikipedia", - implementation=LangChainWikiRetriever, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - kwargs={"chunk_size": 300, "chunk_overlap": 0}, - is_visible=True, - is_available=LangChainWikiRetriever.is_available(), - error_message="LangChainWikiRetriever not available.", - category=Category.DataLoader, - description="Retrieves documents from Wikipedia using LangChain.", - ), - ToolName.Calculator: ManagedTool( - display_name="Calculator", - implementation=Calculator, - parameter_definitions={ - "code": { - "description": "The expression for the calculator to evaluate, it should be a valid mathematical expression.", - "type": "str", - "required": True, - } - }, - is_visible=False, - is_available=Calculator.is_available(), - error_message="Calculator tool not available.", - category=Category.Function, - description="This is a powerful multi-purpose calculator which is capable of a wide array of math calculations.", - ), - ToolName.Google_Drive: ManagedTool( - display_name="Google Drive", - implementation=GoogleDrive, - parameter_definitions={ - "query": { - "description": "Query to search Google Drive documents with.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=GoogleDrive.is_available(), - auth_implementation=GoogleDriveAuth, - error_message="Google Drive not available, please enable it in the GoogleDrive tool class.", - category=Category.DataLoader, - description="Returns a list of relevant document snippets for the user's google drive.", - ), - ToolName.Web_Scrape: ManagedTool( - name=ToolName.Web_Scrape, - display_name="Web Scrape", - implementation=WebScrapeTool, - parameter_definitions={ - "url": { - "description": "The url to scrape.", - "type": "str", - "required": True, - }, - "query": { - "description": "The query to use to select the most relevant passages to return. Using an empty string will return the passages in the order they appear on the webpage", - "type": "str", - "required": False, - }, - }, - is_visible=True, - is_available=WebScrapeTool.is_available(), - error_message="WebScrapeTool not available.", - category=Category.DataLoader, - description="Scrape and returns the textual contents of a webpage as a list of passages for a given url.", - ), - ToolName.Tavily_Web_Search: ManagedTool( - display_name="Web Search", - implementation=TavilyWebSearch, - parameter_definitions={ - "query": { - "description": "Query to search the internet with", - "type": "str", - "required": True, - } - }, - is_visible=False, - is_available=TavilyWebSearch.is_available(), - error_message="TavilyWebSearch not available, please make sure to set the tools.tavily_web_search.api_key variable in your secrets.yaml", - category=Category.WebSearch, - description="Returns a list of relevant document snippets for a textual query retrieved from the internet.", - ), - ToolName.Google_Web_Search: ManagedTool( - display_name="Google Web Search", - implementation=GoogleWebSearch, - parameter_definitions={ - "query": { - "description": "A search query for the Google search engine.", - "type": "str", - "required": True, - } - }, - is_visible=False, - is_available=GoogleWebSearch.is_available(), - error_message="Google Web Search not available, please enable it in the GoogleWebSearch tool class.", - category=Category.WebSearch, - description="Returns relevant results by performing a Google web search.", - ), - ToolName.Brave_Web_Search: ManagedTool( - display_name="Brave Web Search", - implementation=BraveWebSearch, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - is_visible=False, - is_available=BraveWebSearch.is_available(), - error_message="BraveWebSearch not available, please make sure to set the tools.brave_web_search.api_key variable in your secrets.yaml", - category=Category.WebSearch, - description="Returns a list of relevant document snippets for a textual query retrieved from the internet using Brave Search.", - ), - ToolName.Hybrid_Web_Search: ManagedTool( - display_name="Hybrid Web Search", - implementation=HybridWebSearch, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=HybridWebSearch.is_available(), - error_message="HybridWebSearch not available, please make sure to set at least one option in the tools.hybrid_web_search.enabled_web_searches variable in your configuration.yaml", - category=Category.WebSearch, - description="Returns a list of relevant document snippets for a textual query retrieved from the internet using a mix of any existing Web Search tools.", - ), - ToolName.Slack: ManagedTool( - display_name="Slack", - implementation=SlackTool, - parameter_definitions={ - "query": { - "description": "Query to search slack.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=SlackTool.is_available(), - auth_implementation=SlackAuth, - error_message="SlackTool not available, please enable it in the SlackTool class.", - category=Category.DataLoader, - description="Returns a list of relevant document snippets from slack.", - ), -} - - -def get_available_tools() -> dict[ToolName, dict]: +class Tool(Enum): + Wiki_Retriever_LangChain = LangChainWikiRetriever + Read_File = ReadFileTool + Search_File = SearchFileTool + Python_Interpreter = PythonInterpreter + Calculator = Calculator + Google_Drive = GoogleDrive + Web_Scrape = WebScrapeTool + Tavily_Web_Search = TavilyWebSearch + Google_Web_Search = GoogleWebSearch + Brave_Web_Search = BraveWebSearch + Hybrid_Web_Search = HybridWebSearch + Slack = SlackTool + + +def get_available_tools() -> dict[str, ToolDefinition]: + # Get list of implementations from Tool Enum + tool_classes = [tool.value for tool in Tool] + # Generate dictionary of ToolDefinitions keyed by Tool ID + tools = { + tool.ID: tool.get_tool_definition() for tool in tool_classes + } + + # Handle adding Community-implemented tools use_community_tools = Settings().get('feature_flags.use_community_features') - - tools = ALL_TOOLS.copy() if use_community_tools: try: - from community.config.tools import COMMUNITY_TOOLS - - tools = ALL_TOOLS.copy() - tools.update(COMMUNITY_TOOLS) + from community.config.tools import get_community_tools + community_tools = get_community_tools() + tools.update(community_tools) except ImportError: logger.warning( event="[Tools] Error loading tools: Community tools not available." ) - for tool in tools.values(): - # Conditionally set error message - tool.error_message = tool.error_message if not tool.is_available else None - # Retrieve name - tool.name = tool.implementation.NAME - - enabled_tools = Settings().get('tools.enabled_tools') - if enabled_tools is not None and len(enabled_tools) > 0: - tools = {key: value for key, value in tools.items() if key in enabled_tools} return tools - - -AVAILABLE_TOOLS = get_available_tools() diff --git a/src/backend/main.py b/src/backend/main.py index 9569cde052..3bdd288a30 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -1,3 +1,5 @@ +import logging + from alembic.command import upgrade from alembic.config import Config from dotenv import load_dotenv @@ -29,6 +31,9 @@ from backend.services.context import ContextMiddleware, get_context from backend.services.logger.middleware import LoggingMiddleware +# Only show errors for Pydantic +logging.getLogger('pydantic').setLevel(logging.ERROR) + load_dotenv() # CORS Origins diff --git a/src/backend/pytest.ini b/src/backend/pytest.ini index e8ad063ce5..3ba10e4a74 100644 --- a/src/backend/pytest.ini +++ b/src/backend/pytest.ini @@ -1,3 +1,5 @@ [pytest] env = - DATABASE_URL=postgresql://postgres:postgres@localhost:5433/postgres \ No newline at end of file + DATABASE_URL=postgresql://postgres:postgres@localhost:5433/postgres +filterwarnings = + ignore::UserWarning:pydantic.* diff --git a/src/backend/routers/auth.py b/src/backend/routers/auth.py index b726fef898..3e3d6b52a7 100644 --- a/src/backend/routers/auth.py +++ b/src/backend/routers/auth.py @@ -9,7 +9,7 @@ from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING from backend.config.routers import RouterName from backend.config.settings import Settings -from backend.config.tools import AVAILABLE_TOOLS, ToolName +from backend.config.tools import Tool, get_available_tools from backend.crud import blacklist as blacklist_crud from backend.database_models import Blacklist from backend.database_models.database import DBSessionDep @@ -295,8 +295,9 @@ def log_and_redirect_err(error_message: str): err = f"Tool Auth cache {tool_auth_cache} does not contain user_id or tool_id." log_and_redirect_err(err) - if tool_id in AVAILABLE_TOOLS: - tool = AVAILABLE_TOOLS.get(tool_id) + available_tools = get_available_tools() + if tool_id in available_tools: + tool = available_tools.get(tool_id) err = None # Tool not found @@ -336,7 +337,7 @@ async def delete_tool_auth( If completed, the corresponding ToolAuth for the requesting user is removed from the DB. Args: - tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the ToolName string enum class. + tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the Tool string enum class. request (Request): current Request object. session (DBSessionDep): Database session. ctx (Context): Context object. @@ -356,16 +357,16 @@ async def delete_tool_auth( if user_id is None or user_id == "" or user_id == "default": logger.error_and_raise_http_exception(event="User ID not found.") - if tool_id not in [tool_name.value for tool_name in ToolName]: + if tool_id not in [tool_name.value for tool_name in Tool]: logger.error_and_raise_http_exception( - event="tool_id must be present in the path of the request and must be a member of the ToolName string enum class.", + event="tool_id must be present in the path of the request and must be a member of the Tool string enum class.", ) - tool = AVAILABLE_TOOLS.get(tool_id) + tool = get_available_tools().get(tool_id) if tool is None: logger.error_and_raise_http_exception( - event=f"Tool {tool_id} is not available in AVAILABLE_TOOLS." + event=f"Tool {tool_id} is not available." ) if tool.auth_implementation is None: diff --git a/src/backend/routers/organization.py b/src/backend/routers/organization.py index f1a14c2512..6c252f7c6c 100644 --- a/src/backend/routers/organization.py +++ b/src/backend/routers/organization.py @@ -87,9 +87,10 @@ def get_organization( Args: organization_id (str): Tool ID. session (DBSessionDep): Database session. + ctx: Context. Returns: - ManagedTool: Organization with the given ID. + Organization: Organization with the given ID. """ organization = organization_crud.get_organization(session, organization_id) if not organization: @@ -135,7 +136,7 @@ def list_organizations( session (DBSessionDep): Database session. Returns: - list[ManagedTool]: List of available organizations. + list[Organization]: List of available organizations. """ all_organizations = organization_crud.get_organizations(session) return all_organizations diff --git a/src/backend/routers/tool.py b/src/backend/routers/tool.py index a0e95fb3ba..b9078ebc0c 100644 --- a/src/backend/routers/tool.py +++ b/src/backend/routers/tool.py @@ -1,10 +1,10 @@ from fastapi import APIRouter, Depends, Request from backend.config.routers import RouterName -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.database_models.database import DBSessionDep from backend.schemas.context import Context -from backend.schemas.tool import ManagedTool +from backend.schemas.tool import ToolDefinition from backend.services.agent import validate_agent_exists from backend.services.context import get_context @@ -12,13 +12,13 @@ router.name = RouterName.TOOL -@router.get("", response_model=list[ManagedTool]) +@router.get("", response_model=list[ToolDefinition]) def list_tools( request: Request, session: DBSessionDep, agent_id: str | None = None, ctx: Context = Depends(get_context), -) -> list[ManagedTool]: +) -> list[ToolDefinition]: """ List all available tools. @@ -28,19 +28,20 @@ def list_tools( agent_id (str): Agent ID. ctx (Context): Context object. Returns: - list[ManagedTool]: List of available tools. + list[ToolDefinition]: List of available tools. """ user_id = ctx.get_user_id() logger = ctx.get_logger() - all_tools = AVAILABLE_TOOLS.values() + available_tools = get_available_tools() + all_tools = list(available_tools.values()) if agent_id is not None: agent_tools = [] agent = validate_agent_exists(session, agent_id, user_id) for tool in agent.tools: - agent_tools.append(AVAILABLE_TOOLS[tool]) + agent_tools.append(available_tools[tool]) all_tools = agent_tools for tool in all_tools: diff --git a/src/backend/schemas/tool.py b/src/backend/schemas/tool.py index ec92090ae8..d8fa884bd7 100644 --- a/src/backend/schemas/tool.py +++ b/src/backend/schemas/tool.py @@ -4,30 +4,25 @@ from pydantic import BaseModel, Field -class Category(StrEnum): +class ToolCategory(StrEnum): DataLoader = "Data loader" FileLoader = "File loader" Function = "Function" WebSearch = "Web search" -class ToolInput(BaseModel): - pass - - class Tool(BaseModel): name: Optional[str] = "" - display_name: str = "" - description: Optional[str] = "" parameter_definitions: Optional[dict] = {} - -class ManagedTool(Tool): +class ToolDefinition(Tool): + display_name: str = "" + description: str = "" + error_message: Optional[str] = "" kwargs: dict = {} is_visible: bool = False is_available: bool = False - error_message: Optional[str] = "" - category: Category = Category.DataLoader + category: ToolCategory = ToolCategory.DataLoader is_auth_required: bool = False # Per user auth_url: Optional[str] = "" # Per user diff --git a/src/backend/services/auth/strategies/base.py b/src/backend/services/auth/strategies/base.py index 38112519a2..e305545824 100644 --- a/src/backend/services/auth/strategies/base.py +++ b/src/backend/services/auth/strategies/base.py @@ -43,13 +43,14 @@ class BaseOAuthStrategy: def __init__(self, *args, **kwargs): self._post_init_check() - def _post_init_check(self): + @classmethod + def _post_init_check(cls): if any( [ - self.NAME is None, + cls.NAME is None, ] ): - raise ValueError(f"{self.__name__} must have NAME attribute defined.") + raise ValueError(f"{cls.__name__} must have NAME attribute defined.") @abstractmethod def get_client_id(self, **kwargs: Any): diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index bf1d560a7b..8e8abc6e6e 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -9,7 +9,7 @@ from backend.chat.collate import to_dict from backend.chat.enums import StreamEvent -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.crud import agent_tool_metadata as agent_tool_metadata_crud from backend.crud import conversation as conversation_crud from backend.crud import message as message_crud @@ -156,7 +156,7 @@ def process_chat( tools = chat_request.tools managed_tools = ( - len([tool.name for tool in tools if tool.name in AVAILABLE_TOOLS]) > 0 + len([tool.name for tool in tools if tool.name in get_available_tools()]) > 0 ) return ( @@ -253,7 +253,7 @@ def process_message_regeneration( ) managed_tools = ( - len([tool.name for tool in chat_request.tools if tool.name in AVAILABLE_TOOLS]) > 0 + len([tool.name for tool in chat_request.tools if tool.name in get_available_tools()]) > 0 ) return ( @@ -313,7 +313,7 @@ def is_custom_tool_call(chat_response: BaseChatRequest) -> bool: # check if any of the tools is not in the available tools for tool in chat_response.tools: - if tool.name not in AVAILABLE_TOOLS: + if tool.name not in get_available_tools(): return True return False diff --git a/src/backend/services/file.py b/src/backend/services/file.py index f1d8b02234..d52212efc7 100644 --- a/src/backend/services/file.py +++ b/src/backend/services/file.py @@ -132,7 +132,7 @@ def get_files_by_agent_id( Returns: list[File]: The files that were created """ - from backend.config.tools import ToolName + from backend.config.tools import Tool from backend.tools.files import FileToolsArtifactTypes agent = validate_agent_exists(session, agent_id, user_id) @@ -144,8 +144,8 @@ def get_files_by_agent_id( ( tool_metadata.artifacts for tool_metadata in agent_tool_metadata - if tool_metadata.tool_name == ToolName.Read_File - or tool_metadata.tool_name == ToolName.Search_File + if tool_metadata.tool_name == Tool.Read_File.value.ID + or tool_metadata.tool_name == Tool.Search_File.value.ID ), [], # Default value if the generator is empty ) diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py index badb2b4369..21d6012628 100644 --- a/src/backend/services/request_validators.py +++ b/src/backend/services/request_validators.py @@ -8,7 +8,7 @@ find_config_by_deployment_id, find_config_by_deployment_name, ) -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import get_available_tools from backend.crud import agent as agent_crud from backend.crud import conversation as conversation_crud from backend.crud import deployment as deployment_crud @@ -212,7 +212,7 @@ async def validate_chat_request(session: DBSessionDep, request: Request): if not tools: return - managed_tools = [tool["name"] for tool in tools if tool["name"] in AVAILABLE_TOOLS] + managed_tools = [tool["name"] for tool in tools if tool["name"] in get_available_tools()] if managed_tools and len(tools) != len(managed_tools): raise HTTPException( status_code=400, detail="Cannot mix both managed and custom tools" @@ -288,7 +288,7 @@ async def validate_create_agent_request(session: DBSessionDep, request: Request) tools = body.get("tools") if tools: for tool in tools: - if tool not in AVAILABLE_TOOLS: + if tool not in get_available_tools(): raise HTTPException(status_code=404, detail=f"Tool {tool} not found.") name = body.get("name") @@ -339,7 +339,7 @@ async def validate_update_agent_request(session: DBSessionDep, request: Request) tools = body.get("tools") if tools: for tool in tools: - if tool not in AVAILABLE_TOOLS: + if tool not in get_available_tools(): logger.error(event="Tool not found.", tool=tool) raise HTTPException(status_code=404, detail=f"Tool {tool} not found.") diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index 9661606fe2..9ba0be0649 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Session from backend.config.deployments import ModelDeploymentName -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.tests.unit.factories import get_factory @@ -18,7 +18,7 @@ def test_create_agent(session_client: TestClient, session: Session, user) -> Non "temperature": 0.5, "model": "command-r-plus", "deployment": ModelDeploymentName.CoherePlatform, - "tools": [ToolName.Calculator, ToolName.Search_File, ToolName.Read_File], + "tools": [Tool.Calculator.value.ID, Tool.Search_File.value.ID, Tool.Read_File.value.ID], } response = session_client.post( @@ -59,10 +59,10 @@ def test_create_agent_with_tool_metadata( "temperature": 0.5, "model": "command-r-plus", "deployment": ModelDeploymentName.CoherePlatform, - "tools": [ToolName.Google_Drive, ToolName.Search_File], + "tools": [Tool.Google_Drive.value.ID, Tool.Search_File.value.ID], "tools_metadata": [ { - "tool_name": ToolName.Google_Drive, + "tool_name": Tool.Google_Drive.value.ID, "artifacts": [ { "name": "/folder", @@ -72,7 +72,7 @@ def test_create_agent_with_tool_metadata( ], }, { - "tool_name": ToolName.Search_File, + "tool_name": Tool.Search_File.value.ID, "artifacts": [ { "name": "file.txt", @@ -96,11 +96,11 @@ def test_create_agent_with_tool_metadata( .all() ) assert len(tool_metadata) == 2 - assert tool_metadata[0].tool_name == ToolName.Google_Drive + assert tool_metadata[0].tool_name == Tool.Google_Drive.value.ID assert tool_metadata[0].artifacts == [ {"name": "/folder", "ids": "folder1", "type": "folder_ids"}, ] - assert tool_metadata[1].tool_name == ToolName.Search_File + assert tool_metadata[1].tool_name == Tool.Search_File.value.ID assert tool_metadata[1].artifacts == [ {"name": "file.txt", "ids": "file1", "type": "file_ids"} ] diff --git a/src/backend/tests/unit/chat/test_tool_calls.py b/src/backend/tests/unit/chat/test_tool_calls.py index e30173d295..b161049de3 100644 --- a/src/backend/tests/unit/chat/test_tool_calls.py +++ b/src/backend/tests/unit/chat/test_tool_calls.py @@ -6,16 +6,22 @@ from fastapi import HTTPException from backend.chat.custom.tool_calls import async_call_tools -from backend.config.tools import AVAILABLE_TOOLS, ToolName -from backend.schemas.tool import ManagedTool +from backend.config.tools import Tool +from backend.schemas.tool import ToolDefinition from backend.services.context import Context from backend.tests.unit.model_deployments.mock_deployments import MockCohereDeployment from backend.tools.base import BaseTool -def test_async_call_tools_success() -> None: +@pytest.fixture +def mock_get_available_tools(): + with patch("backend.chat.custom.tool_calls.get_available_tools") as mock: + yield mock + + +def test_async_call_tools_success(mock_get_available_tools) -> None: class MockCalculator(BaseTool): - NAME = "toolkit_calculator" + ID = "toolkit_calculator" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -26,29 +32,28 @@ async def call( chat_history = [ { "tool_calls": [ - {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}} + {"name": "toolkit_calculator", "parameters": {"code": "6*7"}} ] } ] - MOCKED_TOOLS = {ToolName.Calculator: ManagedTool(implementation=MockCalculator)} - with patch.dict(AVAILABLE_TOOLS, MOCKED_TOOLS): - results = asyncio.run( - async_call_tools(chat_history, MockCohereDeployment(), ctx) - ) - assert results == [ - { - "call": { - "name": "toolkit_calculator", - "parameters": {"expression": "6*7"}, - }, - "outputs": [{"result": 42}], - } - ] - - -def test_async_call_tools_failure() -> None: + mock_get_available_tools.return_value = {Tool.Calculator.value.ID: ToolDefinition(implementation=MockCalculator)} + results = asyncio.run( + async_call_tools(chat_history, MockCohereDeployment(), ctx) + ) + assert results == [ + { + "call": { + "name": "toolkit_calculator", + "parameters": {"code": "6*7"}, + }, + "outputs": [{"result": 42}], + } + ] + + +def test_async_call_tools_failure(mock_get_available_tools) -> None: class MockCalculator(BaseTool): - NAME = "toolkit_calculator" + ID = "toolkit_calculator" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -59,32 +64,31 @@ async def call( chat_history = [ { "tool_calls": [ - {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}} + {"name": "toolkit_calculator", "parameters": {"code": "6*7"}} ] } ] - MOCKED_TOOLS = {ToolName.Calculator: ManagedTool(implementation=MockCalculator)} - with patch.dict(AVAILABLE_TOOLS, MOCKED_TOOLS): - results = asyncio.run( - async_call_tools(chat_history, MockCohereDeployment(), ctx) - ) - assert results == [ - { - "call": { - "name": "toolkit_calculator", - "parameters": {"expression": "6*7"}, - }, - "outputs": [ - {"error": "Calculator failed", "status_code": 500, "success": False} - ], + mock_get_available_tools.return_value = {Tool.Calculator.value.ID: ToolDefinition(implementation=MockCalculator)} + results = asyncio.run( + async_call_tools(chat_history, MockCohereDeployment(), ctx) + ) + assert results == [ + { + "call": { + "name": "toolkit_calculator", + "parameters": {"code": "6*7"}, }, - ] + "outputs": [ + {"error": "Calculator failed", "status_code": 500, "success": False} + ], + }, + ] @patch("backend.chat.custom.tool_calls.TIMEOUT_SECONDS", 1) -def test_async_call_tools_timeout() -> None: +def test_async_call_tools_timeout(mock_get_available_tools) -> None: class MockCalculator(BaseTool): - NAME = "toolkit_calculator" + ID = "toolkit_calculator" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -96,23 +100,23 @@ async def call( chat_history = [ { "tool_calls": [ - {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}} + {"name": "toolkit_calculator", "parameters": {"code": "6*7"}} ] } ] - MOCKED_TOOLS = {ToolName.Calculator: ManagedTool(implementation=MockCalculator)} - with patch.dict(AVAILABLE_TOOLS, MOCKED_TOOLS): - with pytest.raises(HTTPException) as excinfo: - asyncio.run(async_call_tools(chat_history, MockCohereDeployment(), ctx)) - assert str(excinfo.value.status_code) == "500" - assert ( - str(excinfo.value.detail) == "Timeout while calling tools with timeout: 1" + mock_get_available_tools.return_value = {Tool.Calculator.value.ID: ToolDefinition(implementation=MockCalculator)} + + with pytest.raises(HTTPException) as excinfo: + asyncio.run(async_call_tools(chat_history, MockCohereDeployment(), ctx)) + assert str(excinfo.value.status_code) == "500" + assert ( + str(excinfo.value.detail) == "Timeout while calling tools with timeout: 1" ) -def test_async_call_tools_failure_and_success() -> None: +def test_async_call_tools_failure_and_success(mock_get_available_tools) -> None: class MockWebScrape(BaseTool): - NAME = "web_scrape" + ID = "web_scrape" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -120,7 +124,7 @@ async def call( raise Exception("Web scrape failed") class MockCalculator(BaseTool): - NAME = "toolkit_calculator" + ID = "toolkit_calculator" async def call( self, parameters: dict, ctx: Any, **kwargs: Any @@ -131,26 +135,26 @@ async def call( chat_history = [ { "tool_calls": [ - {"name": "web_scrape", "parameters": {"expression": "6*7"}}, - {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}}, + {"name": "web_scrape", "parameters": {"code": "6*7"}}, + {"name": "toolkit_calculator", "parameters": {"code": "6*7"}}, ] } ] - MOCKED_TOOLS = { - ToolName.Calculator: ManagedTool(implementation=MockCalculator), - ToolName.Web_Scrape: ManagedTool(implementation=MockWebScrape), + mock_get_available_tools.return_value = { + Tool.Calculator.value.ID: ToolDefinition(implementation=MockCalculator), + Tool.Web_Scrape.value.ID: ToolDefinition(implementation=MockWebScrape), } - with patch.dict(AVAILABLE_TOOLS, MOCKED_TOOLS): - results = asyncio.run( - async_call_tools(chat_history, MockCohereDeployment(), ctx) - ) - assert { - "call": {"name": "web_scrape", "parameters": {"expression": "6*7"}}, - "outputs": [ - {"error": "Web scrape failed", "status_code": 500, "success": False} - ], - } in results - assert { - "call": {"name": "toolkit_calculator", "parameters": {"expression": "6*7"}}, - "outputs": [{"result": 42}], - } in results + + results = asyncio.run( + async_call_tools(chat_history, MockCohereDeployment(), ctx) + ) + assert { + "call": {"name": "web_scrape", "parameters": {"code": "6*7"}}, + "outputs": [ + {"error": "Web scrape failed", "status_code": 500, "success": False} + ], + } in results + assert { + "call": {"name": "toolkit_calculator", "parameters": {"code": "6*7"}}, + "outputs": [{"result": 42}], + } in results diff --git a/src/backend/tests/unit/config/test_deployments.py b/src/backend/tests/unit/config/test_deployments.py index bb6bac146f..adaa443040 100644 --- a/src/backend/tests/unit/config/test_deployments.py +++ b/src/backend/tests/unit/config/test_deployments.py @@ -1,13 +1,6 @@ -from unittest.mock import Mock +from backend.config.tools import Tool -from backend.config.deployments import ( - get_default_deployment, -) -from backend.tests.unit.model_deployments.mock_deployments.mock_cohere_platform import ( - MockCohereDeployment, -) - -def test_get_default_deployment(mock_available_model_deployments: Mock) -> None: - default_deployment = get_default_deployment() - assert isinstance(default_deployment, MockCohereDeployment) +def test_all_tools_have_id() -> None: + for tool in Tool: + assert tool.value.ID is not None diff --git a/src/backend/tests/unit/config/test_tools.py b/src/backend/tests/unit/config/test_tools.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/backend/tests/unit/configuration.yaml b/src/backend/tests/unit/configuration.yaml index a620a18a20..501c4b531e 100644 --- a/src/backend/tests/unit/configuration.yaml +++ b/src/backend/tests/unit/configuration.yaml @@ -16,7 +16,6 @@ database: redis: url: tools: - enabled_tools: python_interpreter: url: feature_flags: diff --git a/src/backend/tests/unit/crud/test_agent.py b/src/backend/tests/unit/crud/test_agent.py index 18d92b6c15..5da2fafdc8 100644 --- a/src/backend/tests/unit/crud/test_agent.py +++ b/src/backend/tests/unit/crud/test_agent.py @@ -2,7 +2,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.sql.expression import false -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.crud import agent as agent_crud from backend.database_models.agent import Agent from backend.schemas.agent import AgentVisibility, UpdateAgentRequest @@ -17,7 +17,7 @@ def test_create_agent(session, user): description="test", preamble="test", temperature=0.5, - tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], + tools=[Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID], is_private=True, ) @@ -28,7 +28,7 @@ def test_create_agent(session, user): assert agent.description == "test" assert agent.preamble == "test" assert agent.temperature == 0.5 - assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File] + assert agent.tools == [Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID] assert agent.is_private agent = agent_crud.get_agent_by_id(session, agent.id, user.id) @@ -38,7 +38,7 @@ def test_create_agent(session, user): assert agent.description == "test" assert agent.preamble == "test" assert agent.temperature == 0.5 - assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File] + assert agent.tools == [Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID] def test_create_agent_empty_non_required_fields(session, user): @@ -87,7 +87,7 @@ def test_create_agent_duplicate_name_version(session, user): description="test", preamble="test", temperature=0.5, - tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], + tools=[Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID], ) with pytest.raises(IntegrityError): @@ -205,7 +205,7 @@ def test_update_agent(session, user): preamble="test", temperature=0.5, user=user, - tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], + tools=[Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID], ) new_agent_data = UpdateAgentRequest( @@ -214,7 +214,7 @@ def test_update_agent(session, user): version=2, preamble="new_test", temperature=0.6, - tools=[ToolName.Python_Interpreter, ToolName.Calculator], + tools=[Tool.Python_Interpreter.value.ID, Tool.Calculator.value.ID], ) agent = agent_crud.update_agent(session, agent, new_agent_data, user.id) @@ -223,7 +223,7 @@ def test_update_agent(session, user): assert agent.version == new_agent_data.version assert agent.preamble == new_agent_data.preamble assert agent.temperature == new_agent_data.temperature - assert agent.tools == [ToolName.Python_Interpreter, ToolName.Calculator] + assert agent.tools == [Tool.Python_Interpreter.value.ID, Tool.Calculator.value.ID] def test_delete_agent(session, user): diff --git a/src/backend/tests/unit/crud/test_agent_tool_metadata.py b/src/backend/tests/unit/crud/test_agent_tool_metadata.py index 813e85cbf3..d6ec6a9e64 100644 --- a/src/backend/tests/unit/crud/test_agent_tool_metadata.py +++ b/src/backend/tests/unit/crud/test_agent_tool_metadata.py @@ -1,7 +1,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.crud import agent_tool_metadata as agent_tool_metadata_crud from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.schemas.agent import UpdateAgentToolMetadataRequest @@ -23,13 +23,13 @@ def test_create_agent_tool_metadata(session, user): agent = get_factory("Agent", session).create( - id="1", name="test_agent", tools=[ToolName.Google_Drive], user=user + id="1", name="test_agent", tools=[Tool.Google_Drive.value.ID], user=user ) agent_tool_metadata_data = AgentToolMetadata( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1], ) agent_tool_metadata = agent_tool_metadata_crud.create_agent_tool_metadata( @@ -37,7 +37,7 @@ def test_create_agent_tool_metadata(session, user): ) assert agent_tool_metadata.user_id == user.id assert agent_tool_metadata.agent_id == agent.id - assert agent_tool_metadata.tool_name == ToolName.Google_Drive + assert agent_tool_metadata.tool_name == Tool.Google_Drive.value.ID assert agent_tool_metadata.artifacts == [mock_artifact_1] assert agent_tool_metadata.artifacts[0]["type"] == "document" @@ -46,7 +46,7 @@ def test_create_agent_tool_metadata(session, user): ) assert agent_tool_metadata.user_id == user.id assert agent_tool_metadata.agent_id == agent.id - assert agent_tool_metadata.tool_name == ToolName.Google_Drive + assert agent_tool_metadata.tool_name == Tool.Google_Drive.value.ID assert agent_tool_metadata.artifacts == [mock_artifact_1] assert agent_tool_metadata.artifacts[0]["type"] == "document" @@ -54,7 +54,7 @@ def test_create_agent_tool_metadata(session, user): def test_create_agent_missing_agent_id(session, user): agent_tool_metadata_data = AgentToolMetadata( user_id=user.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1], ) with pytest.raises(IntegrityError): @@ -65,7 +65,7 @@ def test_create_agent_missing_agent_id(session, user): def test_create_agent_missing_tool_name(session, user): agent = get_factory("Agent", session).create( - id="1", name="test_agent", tools=[ToolName.Google_Drive], user=user + id="1", name="test_agent", tools=[Tool.Google_Drive.value.ID], user=user ) agent_tool_metadata_data = AgentToolMetadata( @@ -81,12 +81,12 @@ def test_create_agent_missing_tool_name(session, user): def test_create_agent_missing_user_id(session, user): agent = get_factory("Agent", session).create( - id="1", name="test_agent", tools=[ToolName.Google_Drive], user=user + id="1", name="test_agent", tools=[Tool.Google_Drive.value.ID], user=user ) agent_tool_metadata_data = AgentToolMetadata( agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1], user_id="123", ) @@ -101,7 +101,7 @@ def test_update_agent_tool_metadata(session, user): original_agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1], ) @@ -123,7 +123,7 @@ def test_get_agent_tool_metadata_by_id(session, user): agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1, mock_artifact_2], ) agent_tool_metadata = agent_tool_metadata_crud.get_agent_tool_metadata_by_id( @@ -131,7 +131,7 @@ def test_get_agent_tool_metadata_by_id(session, user): ) assert agent_tool_metadata.user_id == user.id assert agent_tool_metadata.agent_id == agent.id - assert agent_tool_metadata.tool_name == ToolName.Google_Drive + assert agent_tool_metadata.tool_name == Tool.Google_Drive.value.ID assert agent_tool_metadata.artifacts == [mock_artifact_1, mock_artifact_2] @@ -143,18 +143,20 @@ def test_get_all_agent_tool_metadata_by_agent_id(session, user): _ = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent1.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1, mock_artifact_2], ) + # Constraint was added preventing multiple entries for the same user + agent + tool so fixing to change the tool used i = 0 - for tool in ToolName: + for tool in Tool: i += 1 + _ = get_factory("Agent", session).create(user_id=user.id) _ = get_factory("AgentToolMetadata", session).create( id=f"{i}", - tool_name=tool.value, + tool_name=tool.value.ID, artifacts=[mock_artifact_1, mock_artifact_2], user_id=user.id, agent_id=agent2.id, @@ -165,7 +167,7 @@ def test_get_all_agent_tool_metadata_by_agent_id(session, user): session, agent_id=agent2.id ) ) - assert len(all_agent_tool_metadata) == len(ToolName) + assert len(all_agent_tool_metadata) == len(Tool) def test_delete_agent_tool_metadata_by_id(session, user): @@ -173,7 +175,7 @@ def test_delete_agent_tool_metadata_by_id(session, user): agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[mock_artifact_1, mock_artifact_2], ) diff --git a/src/backend/tests/unit/crud/test_tool_auth.py b/src/backend/tests/unit/crud/test_tool_auth.py index 251adce406..c6d2772858 100644 --- a/src/backend/tests/unit/crud/test_tool_auth.py +++ b/src/backend/tests/unit/crud/test_tool_auth.py @@ -1,6 +1,6 @@ from datetime import datetime -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.crud import tool_auth as tool_auth_crud from backend.database_models.tool_auth import ToolAuth from backend.tests.unit.factories import get_factory @@ -9,7 +9,7 @@ def test_create_tool_auth(session, user): tool_auth_data = ToolAuth( user_id=user.id, - tool_id=ToolName.Google_Drive, + tool_id=Tool.Google_Drive.value.ID, token_type="Bearer", encrypted_access_token=bytes(b"foobar"), encrypted_refresh_token=bytes(b"foobar"), @@ -34,7 +34,7 @@ def test_create_tool_auth(session, user): def test_delete_tool_auth_by_tool_id(session, user): tool_auth = get_factory("ToolAuth", session).create( user_id=user.id, - tool_id=ToolName.Google_Drive, + tool_id=Tool.Google_Drive.value.ID, token_type="Bearer", encrypted_access_token=bytes(b"foobar"), encrypted_refresh_token=bytes(b"foobar"), diff --git a/src/backend/tests/unit/factories/agent.py b/src/backend/tests/unit/factories/agent.py index 7b50336e26..0b04348157 100644 --- a/src/backend/tests/unit/factories/agent.py +++ b/src/backend/tests/unit/factories/agent.py @@ -1,6 +1,6 @@ import factory -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.database_models.agent import Agent from backend.tests.unit.factories.base import BaseFactory from backend.tests.unit.factories.user import UserFactory @@ -25,12 +25,12 @@ class Meta: factory.Faker( "random_element", elements=[ - ToolName.Wiki_Retriever_LangChain, - ToolName.Search_File, - ToolName.Read_File, - ToolName.Python_Interpreter, - ToolName.Calculator, - ToolName.Tavily_Web_Search, + Tool.Wiki_Retriever_LangChain.value.ID, + Tool.Search_File.value.ID, + Tool.Read_File.value.ID, + Tool.Python_Interpreter.value.ID, + Tool.Calculator.value.ID, + Tool.Tavily_Web_Search.value.ID, ], ) ] diff --git a/src/backend/tests/unit/factories/agent_tool_metadata.py b/src/backend/tests/unit/factories/agent_tool_metadata.py index d8c9151542..0bb8520528 100644 --- a/src/backend/tests/unit/factories/agent_tool_metadata.py +++ b/src/backend/tests/unit/factories/agent_tool_metadata.py @@ -1,6 +1,6 @@ import factory -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.database_models.agent_tool_metadata import AgentToolMetadata from .base import BaseFactory @@ -16,13 +16,13 @@ class Meta: factory.Faker( "random_element", elements=[ - ToolName.Wiki_Retriever_LangChain, - ToolName.Search_File, - ToolName.Read_File, - ToolName.Python_Interpreter, - ToolName.Calculator, - ToolName.Tavily_Web_Search, - ToolName.Google_Drive, + Tool.Wiki_Retriever_LangChain.value.ID, + Tool.Search_File.value.ID, + Tool.Read_File.value.ID, + Tool.Python_Interpreter.value.ID, + Tool.Calculator.value.ID, + Tool.Tavily_Web_Search.value.ID, + Tool.Google_Drive.value.ID, ], ) ] diff --git a/src/backend/tests/unit/factories/tool_auth.py b/src/backend/tests/unit/factories/tool_auth.py index 8393eae259..af6f198288 100644 --- a/src/backend/tests/unit/factories/tool_auth.py +++ b/src/backend/tests/unit/factories/tool_auth.py @@ -2,7 +2,7 @@ import factory -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.database_models.tool_auth import ToolAuth from .base import BaseFactory @@ -13,7 +13,7 @@ class Meta: model = ToolAuth user_id = factory.Faker("uuid4") - tool_id = ToolName.Google_Drive + tool_id = Tool.Google_Drive.value.ID token_type = "Bearer" encrypted_access_token = bytes(b"foobar") encrypted_refresh_token = bytes(b"foobar") diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index b047318a82..725c2a752e 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from backend.config.deployments import ModelDeploymentName -from backend.config.tools import ToolName +from backend.config.tools import Tool from backend.crud import agent as agent_crud from backend.crud import deployment as deployment_crud from backend.database_models.agent import Agent @@ -135,14 +135,14 @@ def test_create_agent_invalid_tool( "name": "test agent", "model": "command-r-plus", "deployment": ModelDeploymentName.CoherePlatform, - "tools": [ToolName.Calculator, "not a real tool"], + "tools": [Tool.Calculator.value.ID, "fake_tool"], } response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": user.id} ) assert response.status_code == 404 - assert response.json() == {"detail": "Tool not a real tool not found."} + assert response.json() == {"detail": "Tool fake_tool not found."} def test_create_existing_agent( @@ -372,7 +372,7 @@ def test_get_agent(session_client: TestClient, session: Session, user) -> None: agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "name": "/folder1", @@ -393,7 +393,7 @@ def test_get_agent(session_client: TestClient, session: Session, user) -> None: assert response.status_code == 200 response_agent = response.json() assert response_agent["name"] == agent.name - assert response_agent["tools_metadata"][0]["tool_name"] == ToolName.Google_Drive + assert response_agent["tools_metadata"][0]["tool_name"] == Tool.Google_Drive.value.ID assert ( response_agent["tools_metadata"][0]["artifacts"] == agent_tool_metadata.artifacts @@ -498,13 +498,13 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N description="test description", preamble="test preamble", temperature=0.5, - tools=[ToolName.Calculator], + tools=[Tool.Calculator.value.ID], user=user, ) request_json = { "name": "updated name", - "tools": [ToolName.Search_File, ToolName.Read_File], + "tools": [Tool.Search_File.value.ID, Tool.Read_File.value.ID], } response = session_client.put( @@ -519,7 +519,7 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N assert updated_agent["description"] == "test description" assert updated_agent["preamble"] == "test preamble" assert updated_agent["temperature"] == 0.5 - assert updated_agent["tools"] == [ToolName.Search_File, ToolName.Read_File] + assert updated_agent["tools"] == [Tool.Search_File.value.ID, Tool.Read_File.value.ID] def test_update_agent_with_tool_metadata( @@ -537,7 +537,7 @@ def test_update_agent_with_tool_metadata( agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "url": "test", @@ -601,7 +601,7 @@ def test_update_agent_with_tool_metadata_and_new_tool_metadata( agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "url": "test", @@ -681,7 +681,7 @@ def test_update_agent_remove_existing_tool_metadata( get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "url": "test", @@ -809,7 +809,7 @@ def test_update_agent_invalid_tool( request_json = { "model": "not a real model", "deployment": "not a real deployment", - "tools": [ToolName.Calculator, "not a real tool"], + "tools": [Tool.Calculator.value.ID, "not a real tool"], } response = session_client.put( @@ -1036,7 +1036,7 @@ def test_create_agent_tool_metadata( ) -> None: agent = get_factory("Agent", session).create(user=user) request_json = { - "tool_name": ToolName.Google_Drive, + "tool_name": Tool.Google_Drive.value.ID, "artifacts": [ { "name": "/folder1", @@ -1065,7 +1065,7 @@ def test_create_agent_tool_metadata( agent_tool_metadata = session.get( AgentToolMetadata, response_agent_tool_metadata["id"] ) - assert agent_tool_metadata.tool_name == ToolName.Google_Drive + assert agent_tool_metadata.tool_name == Tool.Google_Drive.value.ID assert agent_tool_metadata.artifacts == [ { "name": "/folder1", @@ -1087,7 +1087,7 @@ def test_update_agent_tool_metadata( agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "name": "/folder1", @@ -1148,7 +1148,7 @@ def test_get_agent_tool_metadata( agent_tool_metadata_1 = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ {"name": "/folder", "ids": ["folder1", "folder2"], "type": "folder_ids"} ], @@ -1156,7 +1156,7 @@ def test_get_agent_tool_metadata( agent_tool_metadata_2 = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Search_File, + tool_name=Tool.Search_File.value.ID, artifacts=[{"name": "file.txt", "ids": ["file1", "file2"], "type": "file_ids"}], ) @@ -1182,7 +1182,7 @@ def test_delete_agent_tool_metadata( agent_tool_metadata = get_factory("AgentToolMetadata", session).create( user_id=user.id, agent_id=agent.id, - tool_name=ToolName.Google_Drive, + tool_name=Tool.Google_Drive.value.ID, artifacts=[ { "name": "/folder1", diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 7e8d06ea2e..559865f040 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -13,7 +13,7 @@ from backend.database_models.conversation import Conversation from backend.database_models.message import Message, MessageAgent from backend.database_models.user import User -from backend.schemas.tool import Category +from backend.schemas.tool import ToolCategory from backend.tests.unit.factories import get_factory is_cohere_env_set = ( @@ -375,36 +375,11 @@ def test_streaming_fail_chat_missing_message( } -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -def test_streaming_chat_with_custom_tools(session_client_chat, session_chat, user): - response = session_client_chat.post( - "/v1/chat-stream", - json={ - "message": "Give me a number", - "tools": [ - { - "name": "random_number_generator", - "description": "generate a random number", - } - ], - }, - headers={ - "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, - }, - ) - - assert response.status_code == 200 - validate_chat_streaming_response( - response, user, session_chat, session_client_chat, 0, is_custom_tools=True - ) - - @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_managed_tools(session_client_chat, session_chat, user): tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 - tool = [t for t in tools if t["is_visible"] and t["category"] != Category.Function][ + tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ 0 ].get("name") @@ -446,7 +421,7 @@ def test_streaming_chat_with_managed_and_custom_tools( ): tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 - tool = [t for t in tools if t["is_visible"] and t["category"] != Category.Function][ + tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ 0 ].get("name") @@ -806,7 +781,7 @@ def test_non_streaming_chat( def test_non_streaming_chat_with_managed_tools(session_client_chat, session_chat, user): tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 - tool = [t for t in tools if t["is_visible"] and t["category"] != Category.Function][ + tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ 0 ].get("name") @@ -831,7 +806,7 @@ def test_non_streaming_chat_with_managed_and_custom_tools( ): tools = session_client_chat.get("/v1/tools", headers={"User-Id": user.id}).json() assert len(tools) > 0 - tool = [t for t in tools if t["is_visible"] and t["category"] != Category.Function][ + tool = [t for t in tools if t["is_visible"] and t["category"] != ToolCategory.Function][ 0 ].get("name") @@ -856,30 +831,6 @@ def test_non_streaming_chat_with_managed_and_custom_tools( assert response.status_code == 400 assert response.json() == {"detail": "Cannot mix both managed and custom tools"} - -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -def test_non_streaming_chat_with_custom_tools(session_client_chat, session_chat, user): - response = session_client_chat.post( - "/v1/chat", - json={ - "message": "Give me a number", - "tools": [ - { - "name": "random_number_generator", - "description": "generate a random number", - } - ], - }, - headers={ - "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, - }, - ) - - assert response.status_code == 200 - assert len(response.json()["tool_calls"]) == 1 - - @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_non_streaming_chat_with_search_queries_only( session_client_chat: TestClient, session_chat: Session, user: User diff --git a/src/backend/tests/unit/routers/test_tool.py b/src/backend/tests/unit/routers/test_tool.py index 8636bb1181..943dd5bb89 100644 --- a/src/backend/tests/unit/routers/test_tool.py +++ b/src/backend/tests/unit/routers/test_tool.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from backend.config.tools import AVAILABLE_TOOLS, ToolName +from backend.config.tools import Tool, get_available_tools from backend.schemas.user import User from backend.tests.unit.factories import get_factory @@ -9,18 +9,14 @@ def test_list_tools(session_client: TestClient, session: Session) -> None: response = session_client.get("/v1/tools") assert response.status_code == 200 + available_tools = get_available_tools() for tool in response.json(): - assert tool["name"] in AVAILABLE_TOOLS.keys() - - # get tool that has the same name as the tool in the response - tool_definition = AVAILABLE_TOOLS[tool["name"]] - - assert tool["kwargs"] == tool_definition.kwargs - assert tool["is_visible"] == tool_definition.is_visible - assert tool["is_available"] == tool_definition.is_available - assert tool["error_message"] == tool_definition.error_message - assert tool["category"] == tool_definition.category - assert tool["description"] == tool_definition.description + assert tool["name"] in available_tools.keys() + assert tool["kwargs"] is not None + assert tool["is_visible"] is not None + assert tool["is_available"] is not None + assert tool["category"] is not None + assert tool["description"] is not None def test_list_tools_error_message_none_if_available(client: TestClient) -> None: @@ -35,7 +31,7 @@ def test_list_tools_with_agent( session_client: TestClient, session: Session, user: User ) -> None: agent = get_factory("Agent", session).create( - name="test agent", tools=[ToolName.Wiki_Retriever_LangChain], user=user + name="test agent", tools=[Tool.Wiki_Retriever_LangChain.value.ID], user=user ) response = session_client.get("/v1/tools", params={"agent_id": agent.id}) @@ -43,10 +39,10 @@ def test_list_tools_with_agent( assert len(response.json()) == 1 tool = response.json()[0] - assert tool["name"] == ToolName.Wiki_Retriever_LangChain + assert tool["name"] == Tool.Wiki_Retriever_LangChain.value.ID # get tool that has the same name as the tool in the response - tool_definition = AVAILABLE_TOOLS[tool["name"]] + tool_definition = get_available_tools()[tool["name"]] assert tool["kwargs"] == tool_definition.kwargs assert tool["is_visible"] == tool_definition.is_visible diff --git a/src/backend/tools/base.py b/src/backend/tools/base.py index aa66fcd2c9..203a9328dd 100644 --- a/src/backend/tools/base.py +++ b/src/backend/tools/base.py @@ -1,5 +1,5 @@ import datetime -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Any, Dict, List from fastapi import Request @@ -8,31 +8,28 @@ from backend.crud import tool_auth as tool_auth_crud from backend.database_models.database import DBSessionDep from backend.database_models.tool_auth import ToolAuth +from backend.schemas.tool import ToolDefinition from backend.services.logger.utils import LoggerFactory logger = LoggerFactory().get_logger() -class BaseTool: +class BaseTool(): """ Abstract base class for all Tools. Attributes: - NAME (str): The name of the tool. + ID (str): The name of the tool. """ - - NAME = None + ID = None def __init__(self, *args, **kwargs): self._post_init_check() - def _post_init_check(self): - if any( - [ - self.NAME is None, - ] - ): - raise ValueError(f"{self.__name__} must have NAME attribute defined.") + @classmethod + def _post_init_check(cls): + if cls.ID is None: + raise ValueError(f"{cls.__name__} must have ID attribute defined.") @classmethod @abstractmethod @@ -40,6 +37,16 @@ def is_available(cls) -> bool: ... @classmethod @abstractmethod + def get_tool_definition(cls) -> ToolDefinition: ... + + @classmethod + def generate_error_message(cls) -> str | None: + if cls.is_available(): + return None + + return f"{cls.__name__} is not available. Please make sure all required config variables are set." + + @classmethod def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: pass @@ -49,7 +56,7 @@ async def call( ) -> List[Dict[str, Any]]: ... -class BaseToolAuthentication: +class BaseToolAuthentication(ABC): """ Abstract base class for Tool Authentication. """ @@ -61,12 +68,13 @@ def __init__(self, *args, **kwargs): self._post_init_check() - def _post_init_check(self): + @classmethod + def _post_init_check(cls): if any( [ - self.BACKEND_HOST is None, - self.FRONTEND_HOST is None, - self.AUTH_SECRET_KEY is None, + cls.BACKEND_HOST is None, + cls.FRONTEND_HOST is None, + cls.AUTH_SECRET_KEY is None, ] ): raise ValueError( diff --git a/src/backend/tools/brave_search/tool.py b/src/backend/tools/brave_search/tool.py index 85899b6a9d..0fcd9bf207 100644 --- a/src/backend/tools/brave_search/tool.py +++ b/src/backend/tools/brave_search/tool.py @@ -4,13 +4,14 @@ from backend.database_models.database import DBSessionDep from backend.model_deployments.base import BaseDeployment from backend.schemas.agent import AgentToolMetadataArtifactsType +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool from backend.tools.brave_search.client import BraveClient from backend.tools.utils.mixins import WebSearchFilteringMixin class BraveWebSearch(BaseTool, WebSearchFilteringMixin): - NAME = "brave_web_search" + ID = "brave_web_search" BRAVE_API_KEY = Settings().get('tools.brave_web_search.api_key') def __init__(self): @@ -21,6 +22,29 @@ def __init__(self): def is_available(cls) -> bool: return cls.BRAVE_API_KEY is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Brave Web Search", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.WebSearch, + description=( + "Returns a list of relevant document snippets for a textual query retrieved " + "from the internet using Brave Search." + ), + ) + async def call( self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/backend/tools/calculator.py b/src/backend/tools/calculator.py index f4566e2875..3b96859663 100644 --- a/src/backend/tools/calculator.py +++ b/src/backend/tools/calculator.py @@ -2,6 +2,7 @@ from py_expression_eval import Parser +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -10,12 +11,32 @@ class Calculator(BaseTool): Function Tool that evaluates mathematical expressions. """ - NAME = "toolkit_calculator" + ID = "toolkit_calculator" @classmethod def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Calculator", + implementation=Calculator, + parameter_definitions={ + "code": { + "description": "The expression for the calculator to evaluate, it should be a valid mathematical expression.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=Calculator.is_available(), + category=ToolCategory.Function, + error_message=cls.generate_error_message(), + description="A powerful multi-purpose calculator capable of a wide array of math calculations.", + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/backend/tools/files.py b/src/backend/tools/files.py index 3b72b662ad..146a741c0e 100644 --- a/src/backend/tools/files.py +++ b/src/backend/tools/files.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List import backend.crud.file as file_crud +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool @@ -13,7 +14,7 @@ class ReadFileTool(BaseTool): Tool to read a file from the file system. """ - NAME = "read_file" + ID = "read_file" MAX_NUM_CHUNKS = 10 SEARCH_LIMIT = 5 @@ -24,6 +25,33 @@ def __init__(self): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Read Document", + implementation=cls, + parameter_definitions={ + "file": { + "description": "A file represented as a tuple (filename, file ID) to read over", + "type": "tuple[str, str]", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.FileLoader, + description="Returns the chunked textual contents of an uploaded file.", + ) + + def get_info(cls) -> ToolDefinition: + return ToolDefinition( + display_name="Calculator", + description="A powerful multi-purpose calculator capable of a wide array of math calculations.", + error_message=cls.generate_error_message(), + ) + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: file = parameters.get("file") @@ -50,7 +78,8 @@ class SearchFileTool(BaseTool): Tool to query a list of files. """ - NAME = "search_file" + ID = "search_file" + DISPLAY_NAME = "Search Files" MAX_NUM_CHUNKS = 10 SEARCH_LIMIT = 5 @@ -61,6 +90,31 @@ def __init__(self): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Search File", + implementation=cls, + parameter_definitions={ + "search_query": { + "description": "Textual search query to search over the file's content for", + "type": "str", + "required": True, + }, + "files": { + "description": "A list of files represented as tuples of (filename, file ID) to search over", + "type": "list[tuple[str, str]]", + "required": True, + }, + }, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.FileLoader, + description="Searches across one or more attached files based on a textual search query.", + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/backend/tools/google_drive/tool.py b/src/backend/tools/google_drive/tool.py index 3691b75b56..a8c732223b 100644 --- a/src/backend/tools/google_drive/tool.py +++ b/src/backend/tools/google_drive/tool.py @@ -4,8 +4,10 @@ from backend.config.settings import Settings from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool +from backend.tools.google_drive.auth import GoogleDriveAuth from backend.tools.google_drive.constants import GOOGLE_DRIVE_TOOL_ID, SEARCH_LIMIT from backend.tools.google_drive.utils import ( extract_export_link, @@ -23,9 +25,7 @@ class GoogleDrive(BaseTool): """ Tool that searches Google Drive """ - - NAME = GOOGLE_DRIVE_TOOL_ID - + ID = GOOGLE_DRIVE_TOOL_ID CLIENT_ID = Settings().get('tools.google_drive.client_id') CLIENT_SECRET = Settings().get('tools.google_drive.client_secret') @@ -33,6 +33,27 @@ class GoogleDrive(BaseTool): def is_available(cls) -> bool: return cls.CLIENT_ID is not None and cls.CLIENT_SECRET is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Google Drive", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query to search Google Drive documents with.", + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=GoogleDrive.is_available(), + auth_implementation=GoogleDriveAuth, + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Returns a list of relevant document snippets from the user's Google drive.", + ) + def _handle_tool_specific_errors(self, error: Exception, **kwargs: Any): message = "[Google Drive] Tool Error: {}".format(str(error)) diff --git a/src/backend/tools/google_search.py b/src/backend/tools/google_search.py index c8df4216e6..cc2ddc40cd 100644 --- a/src/backend/tools/google_search.py +++ b/src/backend/tools/google_search.py @@ -5,12 +5,13 @@ from backend.config.settings import Settings from backend.database_models.database import DBSessionDep from backend.schemas.agent import AgentToolMetadataArtifactsType +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool from backend.tools.utils.mixins import WebSearchFilteringMixin class GoogleWebSearch(BaseTool, WebSearchFilteringMixin): - NAME = "google_web_search" + ID = "google_web_search" API_KEY = Settings().get('tools.google_web_search.api_key') CSE_ID = Settings().get('tools.google_web_search.cse_id') @@ -21,6 +22,26 @@ def __init__(self): def is_available(cls) -> bool: return bool(cls.API_KEY) and bool(cls.CSE_ID) + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Google Web Search", + implementation=cls, + parameter_definitions={ + "query": { + "description": "A search query for the Google search engine.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.WebSearch, + description="Returns relevant results by performing a Google web search.", + ) + async def call( self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/backend/tools/hybrid_search.py b/src/backend/tools/hybrid_search.py index 8af1e98cc3..e6bf4973ec 100644 --- a/src/backend/tools/hybrid_search.py +++ b/src/backend/tools/hybrid_search.py @@ -6,6 +6,7 @@ from backend.database_models.database import DBSessionDep from backend.model_deployments.base import BaseDeployment from backend.schemas.agent import AgentToolMetadataArtifactsType +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool from backend.tools.brave_search.tool import BraveWebSearch from backend.tools.google_search import GoogleWebSearch @@ -15,7 +16,7 @@ class HybridWebSearch(BaseTool, WebSearchFilteringMixin): - NAME = "hybrid_web_search" + ID = "hybrid_web_search" POST_RERANK_MAX_RESULTS = 6 AVAILABLE_WEB_SEARCH_TOOLS = [TavilyWebSearch, GoogleWebSearch, BraveWebSearch] ENABLED_WEB_SEARCH_TOOLS = Settings().get('tools.hybrid_web_search.enabled_web_searches') @@ -38,13 +39,36 @@ def is_available(cls) -> bool: # False if empty, True otherwise return bool(available_searches) + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Hybrid Web Search", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.WebSearch, + description=( + "Returns a list of relevant document snippets for a textual query " + "retrieved from the internet using a mix of any existing Web Search tools." + ) + ) + @classmethod def get_available_search_tools(cls): available_search_tools = [] for search_name in cls.ENABLED_WEB_SEARCH_TOOLS: for search_tool in cls.AVAILABLE_WEB_SEARCH_TOOLS: - if search_name == search_tool.NAME and search_tool.is_available(): + if search_name == search_tool.ID and search_tool.is_available(): available_search_tools.append(search_tool) return available_search_tools diff --git a/src/backend/tools/lang_chain.py b/src/backend/tools/lang_chain.py index 9dd64f8eec..345d5f3d79 100644 --- a/src/backend/tools/lang_chain.py +++ b/src/backend/tools/lang_chain.py @@ -7,6 +7,7 @@ from langchain_community.vectorstores import Chroma from backend.config.settings import Settings +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool """ @@ -22,8 +23,7 @@ class LangChainWikiRetriever(BaseTool): This class retrieves documents from Wikipedia using the langchain package. This requires wikipedia package to be installed. """ - - NAME = "wikipedia" + ID = "wikipedia" def __init__(self, chunk_size: int = 300, chunk_overlap: int = 0): self.chunk_size = chunk_size @@ -33,6 +33,27 @@ def __init__(self, chunk_size: int = 300, chunk_overlap: int = 0): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Wikipedia", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + kwargs={"chunk_size": 300, "chunk_overlap": 0}, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Retrieves documents from Wikipedia.", + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: @@ -58,8 +79,7 @@ class LangChainVectorDBRetriever(BaseTool): """ This class retrieves documents from a vector database using the langchain package. """ - - NAME = "vector_retriever" + ID = "vector_retriever" COHERE_API_KEY = Settings().get('deployments.cohere_platform.api_key') def __init__(self, filepath: str): diff --git a/src/backend/tools/python_interpreter.py b/src/backend/tools/python_interpreter.py index 3ebc664124..426844ab48 100644 --- a/src/backend/tools/python_interpreter.py +++ b/src/backend/tools/python_interpreter.py @@ -5,6 +5,7 @@ from dotenv import load_dotenv from backend.config.settings import Settings +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool load_dotenv() @@ -16,13 +17,41 @@ class PythonInterpreter(BaseTool): It requires a URL at which the interpreter lives """ - NAME = "toolkit_python_interpreter" + ID = "toolkit_python_interpreter" INTERPRETER_URL = Settings().get('tools.python_interpreter.url') @classmethod def is_available(cls) -> bool: return cls.INTERPRETER_URL is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Python Interpreter", + implementation=cls, + parameter_definitions={ + "code": { + "description": ( + "Python code to execute using the Python interpreter with no internet access. " + "Do not generate code that tries to open files directly, instead use file contents passed to the interpreter, " + "then print output or save output to a file." + ), + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.Function, + description=( + "Executes python code and returns the result. The code runs " + "in a static sandbox without internet access and without interactive mode, " + "so print output or save output to a file." + ), + ) + async def call(self, parameters: dict, ctx: Any, **kwargs: Any): if not self.INTERPRETER_URL: raise Exception("Python Interpreter tool called while URL not set") diff --git a/src/backend/tools/slack/tool.py b/src/backend/tools/slack/tool.py index c1adee118e..35e78f0aea 100644 --- a/src/backend/tools/slack/tool.py +++ b/src/backend/tools/slack/tool.py @@ -2,8 +2,10 @@ from backend.config.settings import Settings from backend.crud import tool_auth as tool_auth_crud +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool +from backend.tools.slack.auth import SlackAuth from backend.tools.slack.constants import SEARCH_LIMIT, SLACK_TOOL_ID from backend.tools.slack.utils import get_slack_service @@ -15,7 +17,7 @@ class SlackTool(BaseTool): Tool that searches Slack for messages and files based on a query. """ - NAME = SLACK_TOOL_ID + ID = SLACK_TOOL_ID CLIENT_ID = Settings().get('tools.slack.client_id') CLIENT_SECRET = Settings().get('tools.slack.client_secret') @@ -23,6 +25,27 @@ class SlackTool(BaseTool): def is_available(cls) -> bool: return cls.CLIENT_ID is not None and cls.CLIENT_SECRET is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Slack", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query to search slack.", + "type": "str", + "required": True, + } + }, + is_visible=True, + is_available=cls.is_available(), + auth_implementation=SlackAuth, + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Returns a list of relevant document snippets from slack.", + ) + @classmethod def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: message = "[Slack] Tool Error: {}".format(str(error)) diff --git a/src/backend/tools/tavily_search.py b/src/backend/tools/tavily_search.py index abf30db883..0750a2517b 100644 --- a/src/backend/tools/tavily_search.py +++ b/src/backend/tools/tavily_search.py @@ -6,12 +6,13 @@ from backend.database_models.database import DBSessionDep from backend.model_deployments.base import BaseDeployment from backend.schemas.agent import AgentToolMetadataArtifactsType +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.tools.base import BaseTool from backend.tools.utils.mixins import WebSearchFilteringMixin class TavilyWebSearch(BaseTool, WebSearchFilteringMixin): - NAME = "tavily_web_search" + ID = "tavily_web_search" TAVILY_API_KEY = Settings().get('tools.tavily_web_search.api_key') POST_RERANK_MAX_RESULTS = 6 @@ -22,6 +23,26 @@ def __init__(self): def is_available(cls) -> bool: return cls.TAVILY_API_KEY is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Web Search", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query to search the internet with", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.WebSearch, + description="Returns a list of relevant document snippets for a textual query retrieved from the internet.", + ) + async def call( self, parameters: dict, ctx: Any, session: DBSessionDep, **kwargs: Any ) -> List[Dict[str, Any]]: @@ -57,19 +78,22 @@ async def call( # Append original search result expanded.append(result) - # Get other snippets - snippets = result["raw_content"].split("\n") - for snippet in snippets: - if result["content"] != snippet: - if len(snippet.split()) <= 10: - continue # Skip snippets with less than 10 words - - new_result = { - "url": result["url"], - "title": result["title"], - "content": snippet.strip(), - } - expanded.append(new_result) + # Retrieve snippets from raw content if exists + raw_content = result["raw_content"] + if raw_content: + # Get other snippets + snippets = result["raw_content"].split("\n") + for snippet in snippets: + if result["content"] != snippet: + if len(snippet.split()) <= 10: + continue # Skip snippets with less than 10 words + + new_result = { + "url": result["url"], + "title": result["title"], + "content": snippet.strip(), + } + expanded.append(new_result) reranked_results = await self.rerank_page_snippets( query, expanded, model=kwargs.get("model_deployment"), ctx=ctx, **kwargs diff --git a/src/backend/tools/utils/mixins.py b/src/backend/tools/utils/mixins.py index cd04024aab..ce4827140f 100644 --- a/src/backend/tools/utils/mixins.py +++ b/src/backend/tools/utils/mixins.py @@ -42,7 +42,7 @@ def get_filters( agent_tool_metadata = agent_tool_metadata_crud.get_agent_tool_metadata( db=session, agent_id=agent_id, - tool_name=self.NAME, + tool_name=self.ID, user_id=user_id, ) diff --git a/src/backend/tools/utils/tools_checkers.py b/src/backend/tools/utils/tools_checkers.py index 3a9acc66d3..f666cd7845 100644 --- a/src/backend/tools/utils/tools_checkers.py +++ b/src/backend/tools/utils/tools_checkers.py @@ -1,29 +1,29 @@ -from backend.schemas.tool import Category, ManagedTool -from community.config.tools import CommunityToolName +from backend.schemas.tool import ToolCategory, ToolDefinition +from community.config.tools import CommunityTool -def tool_has_category(tool: ManagedTool, category: Category) -> bool: +def tool_has_category(tool: ToolDefinition, category: ToolCategory) -> bool: """ Check if a tool has a specific category. Args: - tool (ManagedTool): The tool to check. - category (Category): The category to check for. + tool (ToolDefinition): The tool to check. + category (ToolCategory): The category to check for. Returns: - bool: True if the tool has the category, False otherwise. + bool: True if the tool has the category, False otherwise. """ return tool.category == category -def is_community_tool(tool: ManagedTool) -> bool: +def is_community_tool(tool: ToolDefinition) -> bool: """ Check if a tool is a community tool. Args: - tool (ManagedTool): The tool to check. + tool (ToolDefinition): The tool to check. Returns: - bool: True if the tool is a community tool, False otherwise. + bool: True if the tool is a community tool, False otherwise. """ - return tool.name in CommunityToolName + return tool.name in CommunityTool diff --git a/src/backend/tools/web_scrape.py b/src/backend/tools/web_scrape.py index 66ccf20f71..5479e951fe 100644 --- a/src/backend/tools/web_scrape.py +++ b/src/backend/tools/web_scrape.py @@ -4,6 +4,7 @@ import aiohttp from langchain_text_splitters import MarkdownHeaderTextSplitter +from backend.schemas.tool import ToolCategory, ToolDefinition from backend.services.logger.utils import LoggerFactory from backend.tools.base import BaseTool @@ -11,7 +12,7 @@ class WebScrapeTool(BaseTool): - NAME = "web_scrape" + ID = "web_scrape" ENDPOINT: ClassVar[str] = "http://co-reader" ENABLE_CHUNKING: ClassVar[bool] = True @@ -19,6 +20,31 @@ class WebScrapeTool(BaseTool): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Web Scrape", + implementation=cls, + parameter_definitions={ + "url": { + "description": "The url to scrape.", + "type": "str", + "required": True, + }, + "query": { + "description": "The query to use to select the most relevant passages to return. Using an empty string will return the passages in the order they appear on the webpage", + "type": "str", + "required": False, + }, + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Scrape and returns the textual contents of a webpage as a list of passages for a given url.", + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/community/config/tools.py b/src/community/config/tools.py index 7ea673690f..8b97859bb8 100644 --- a/src/community/config/tools.py +++ b/src/community/config/tools.py @@ -1,137 +1,31 @@ -from enum import StrEnum +from enum import Enum +from backend.schemas.tool import ToolDefinition from community.tools import ( ArxivRetriever, - Category, ClinicalTrials, ConnectorRetriever, LlamaIndexUploadPDFRetriever, - ManagedTool, PubMedRetriever, WolframAlpha, ) -class CommunityToolName(StrEnum): - Arxiv = ArxivRetriever.NAME - Connector = ConnectorRetriever.NAME - Pub_Med = PubMedRetriever.NAME - File_Upload_LlamaIndex = LlamaIndexUploadPDFRetriever.NAME - Wolfram_Alpha = WolframAlpha.NAME - ClinicalTrials = ClinicalTrials.NAME +class CommunityTool(Enum): + Arxiv = ArxivRetriever + Connector = ConnectorRetriever + Pub_Med = PubMedRetriever + File_Upload_LlamaIndex = LlamaIndexUploadPDFRetriever + Wolfram_Alpha = WolframAlpha + ClinicalTrials = ClinicalTrials -COMMUNITY_TOOLS = { - CommunityToolName.Arxiv: ManagedTool( - display_name="Arxiv", - implementation=ArxivRetriever, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=ArxivRetriever.is_available(), - error_message="ArxivRetriever is not available.", - category=Category.DataLoader, - description="Retrieves documents from Arxiv.", - ), - CommunityToolName.Connector: ManagedTool( - display_name="Example Connector", - implementation=ConnectorRetriever, - is_visible=True, - is_available=ConnectorRetriever.is_available(), - error_message="ConnectorRetriever is not available.", - category=Category.DataLoader, - description="Connects to a data source.", - ), - CommunityToolName.Pub_Med: ManagedTool( - display_name="PubMed", - implementation=PubMedRetriever, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - } - }, - is_visible=True, - is_available=PubMedRetriever.is_available(), - error_message="PubMedRetriever is not available.", - category=Category.DataLoader, - description="Retrieves documents from Pub Med.", - ), - CommunityToolName.File_Upload_LlamaIndex: ManagedTool( - display_name="Llama File Reader", - implementation=LlamaIndexUploadPDFRetriever, - parameter_definitions={ - "query": { - "description": "Query for retrieval.", - "type": "str", - "required": True, - }, - "files": { - "description": "A list of files represented as tuples of (filename, file ID) to search over", - "type": "list[tuple[str, str]]", - "required": True, - }, +def get_community_tools() -> dict[str, ToolDefinition]: + # Get list of implementations from Tool Enum + tool_classes = [tool.value for tool in CommunityTool] + # Generate dictionary of ToolDefinitions keyed by Tool ID + community_tools = { + tool.ID: tool.get_tool_definition() for tool in tool_classes + } - }, - is_visible=True, - is_available=LlamaIndexUploadPDFRetriever.is_available(), - error_message="LlamaIndexUploadPDFRetriever is not available.", - category=Category.FileLoader, - description="Retrieves the most relevant documents from the uploaded files based on the query using Llama Index.", - ), - CommunityToolName.Wolfram_Alpha: ManagedTool( - display_name="Wolfram Alpha", - implementation=WolframAlpha, - is_visible=False, - is_available=WolframAlpha.is_available(), - error_message="WolframAlphaFunctionTool is not available, please set tools.wolfram_alpha.app_id in secrets.yaml", - category=Category.Function, - description="Evaluate arithmetic expressions.", - ), - CommunityToolName.ClinicalTrials: ManagedTool( - display_name="Clinical Trials", - implementation=ClinicalTrials, - is_visible=True, - is_available=ClinicalTrials.is_available(), - error_message="ClinicalTrialsTool is not available.", - category=Category.Function, - description="Retrieves clinical studies from ClinicalTrials.gov.", - parameter_definitions={ - "condition": { - "description": "Filters clinical studies to a specified disease or condition", - "type": "str", - "required": False, - }, - "location": { - "description": "Filters clinical studies to a specified city, state, or country.", - "type": "str", - "required": False, - }, - "intervention": { - "description": "Filters clinical studies to a specified drug or treatment.", - "type": "str", - "required": False, - }, - "is_recruiting": { - "description": "Filters clinical studies to those that are actively recruiting.", - "type": "bool", - "required": False, - }, - }, - ), -} - -# For main.py cli setup script -COMMUNITY_TOOLS_SETUP = { - CommunityToolName.Wolfram_Alpha: { - "secrets": { - "WOLFRAM_APP_ID": None, # default value - }, - }, -} + return community_tools diff --git a/src/community/tools/__init__.py b/src/community/tools/__init__.py index 86a0013172..1cffba1972 100644 --- a/src/community/tools/__init__.py +++ b/src/community/tools/__init__.py @@ -1,5 +1,3 @@ -from backend.schemas.tool import Category, ManagedTool -from backend.tools.base import BaseTool from community.tools.arxiv import ArxivRetriever from community.tools.clinicaltrials import ClinicalTrials from community.tools.connector import ConnectorRetriever @@ -14,7 +12,4 @@ "ConnectorRetriever", "LlamaIndexUploadPDFRetriever", "PubMedRetriever", - "Category", - "ManagedTool", - "BaseTool", ] diff --git a/src/community/tools/arxiv.py b/src/community/tools/arxiv.py index 7d92d87549..ce5cfac71c 100644 --- a/src/community/tools/arxiv.py +++ b/src/community/tools/arxiv.py @@ -2,11 +2,12 @@ from langchain_community.utilities import ArxivAPIWrapper -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool class ArxivRetriever(BaseTool): - NAME = "arxiv" + ID = "arxiv" def __init__(self): self.client = ArxivAPIWrapper() @@ -15,6 +16,26 @@ def __init__(self): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Arxiv", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Retrieves documents from Arxiv.", + ) + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("query", "") result = self.client.run(query) diff --git a/src/community/tools/clinicaltrials.py b/src/community/tools/clinicaltrials.py index 0a8af52aed..3db15271ac 100644 --- a/src/community/tools/clinicaltrials.py +++ b/src/community/tools/clinicaltrials.py @@ -2,7 +2,8 @@ import requests -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool class ClinicalTrials(BaseTool): @@ -12,7 +13,7 @@ class ClinicalTrials(BaseTool): See: https://clinicaltrials.gov/data-api/api """ - NAME = "clinical_trials" + ID = "clinical_trials" def __init__(self, url="https://clinicaltrials.gov/api/v2/studies"): self._url = url @@ -21,6 +22,41 @@ def __init__(self, url="https://clinicaltrials.gov/api/v2/studies"): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Clinical Trials", + implementation=cls, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.Function, + description="Retrieves clinical studies from ClinicalTrials.gov.", + parameter_definitions={ + "condition": { + "description": "Filters clinical studies to a specified disease or condition", + "type": "str", + "required": False, + }, + "location": { + "description": "Filters clinical studies to a specified city, state, or country.", + "type": "str", + "required": False, + }, + "intervention": { + "description": "Filters clinical studies to a specified drug or treatment.", + "type": "str", + "required": False, + }, + "is_recruiting": { + "description": "Filters clinical studies to those that are actively recruiting.", + "type": "bool", + "required": False, + }, + }, + ) + async def call( self, parameters: Dict[str, Any], n_max_studies: int = 10, **kwargs ) -> List[Dict[str, Any]]: diff --git a/src/community/tools/connector.py b/src/community/tools/connector.py index a2af411a6a..b19445ddad 100644 --- a/src/community/tools/connector.py +++ b/src/community/tools/connector.py @@ -2,7 +2,8 @@ import requests -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool """ Plug in your Connector configuration here. For example: @@ -10,28 +11,43 @@ Url: http://example_connector.com/search Auth: Bearer token for the connector +To see SSO examples, check out our Google Drive or Slack tool implementations + More details: https://docs.cohere.com/docs/connectors """ class ConnectorRetriever(BaseTool): - NAME = "example_connector" + ID = "example_connector" - def __init__(self, url: str, auth: str): + def __init__(self, url: str, api_key: str): self.url = url - self.auth = auth + self.api_key = api_key @classmethod def is_available(cls) -> bool: - return True + return False + + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Example Connector Template - Do not use", + implementation=ConnectorRetriever, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Example connector for a data source using a basic API.", + ) async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: body = {"query": parameters} headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.auth}", + "Authorization": f"Bearer {self.api_key}", } - response = requests.post(self.url, json=body, headers=headers) + response = requests.get(self.url, json=body, headers=headers) return response.json()["results"] diff --git a/src/community/tools/llama_index.py b/src/community/tools/llama_index.py index 6cdef8da4c..aafdc1b491 100644 --- a/src/community/tools/llama_index.py +++ b/src/community/tools/llama_index.py @@ -7,7 +7,8 @@ import backend.crud.file as file_crud from backend.config import Settings -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool """ Plug in your llama index retrieval implementation here. @@ -25,7 +26,7 @@ class LlamaIndexUploadPDFRetriever(BaseTool): This requires llama_index package to be installed. """ - NAME = "file_reader_llamaindex" + ID = "file_reader_llamaindex" CHUNK_SIZE = 512 def __init__(self): @@ -39,11 +40,39 @@ def _get_embedding(self, embed_type): input_type=embed_type, ) - @classmethod def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Llama File Reader", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + }, + "files": { + "description": "A list of files represented as tuples of (filename, file ID) to search over", + "type": "list[tuple[str, str]]", + "required": True, + }, + + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.FileLoader, + description=( + "Retrieves the most relevant documents from the uploaded " + "files based on the query using Llama Index." + ) + ) + async def call( self, parameters: dict, ctx: Any, **kwargs: Any ) -> List[Dict[str, Any]]: diff --git a/src/community/tools/pub_med.py b/src/community/tools/pub_med.py index 1ce46c7f80..6968e57ea3 100644 --- a/src/community/tools/pub_med.py +++ b/src/community/tools/pub_med.py @@ -2,11 +2,12 @@ from langchain_community.tools.pubmed.tool import PubmedQueryRun -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool class PubMedRetriever(BaseTool): - NAME = "pub_med" + ID = "pub_med" def __init__(self): self.client = PubmedQueryRun() @@ -15,6 +16,26 @@ def __init__(self): def is_available(cls) -> bool: return True + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Pub Med", + implementation=cls, + parameter_definitions={ + "query": { + "description": "Query for retrieval.", + "type": "str", + "required": True, + } + }, + is_visible=False, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.DataLoader, + description="Retrieves documents from Pub Med.", + ) + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("query", "") result = self.client.invoke(query) diff --git a/src/community/tools/wolfram.py b/src/community/tools/wolfram.py index 9fc022ebab..dc098e77ed 100644 --- a/src/community/tools/wolfram.py +++ b/src/community/tools/wolfram.py @@ -3,7 +3,8 @@ from langchain_community.utilities.wolfram_alpha import WolframAlphaAPIWrapper from backend.config.settings import Settings -from community.tools import BaseTool +from backend.schemas.tool import ToolCategory, ToolDefinition +from backend.tools.base import BaseTool class WolframAlpha(BaseTool): @@ -13,7 +14,7 @@ class WolframAlpha(BaseTool): See: https://python.langchain.com/docs/integrations/tools/wolfram_alpha/ """ - NAME = "wolfram_alpha" + ID = "wolfram_alpha" wolfram_app_id = Settings().get('tools.wolfram_alpha.app_id') @@ -25,6 +26,19 @@ def __init__(self): def is_available(cls) -> bool: return cls.wolfram_app_id is not None + @classmethod + def get_tool_definition(cls) -> ToolDefinition: + return ToolDefinition( + name=cls.ID, + display_name="Wolfram Alpha", + implementation=cls, + is_visible=True, + is_available=cls.is_available(), + error_message=cls.generate_error_message(), + category=ToolCategory.Function, + description="Evaluate arithmetic expressions using Wolfram Alpha.", + ) + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: to_evaluate = parameters.get("expression", "") result = self.tool.run(to_evaluate) diff --git a/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx b/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx index 68fc43bbee..9d0a5aa71b 100644 --- a/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx +++ b/src/interfaces/assistants_web/src/app/(main)/(chat)/Chat.tsx @@ -2,7 +2,7 @@ import { useEffect } from 'react'; -import { Document, ManagedTool } from '@/cohere-client'; +import { Document, ToolDefinition } from '@/cohere-client'; import { Conversation, ConversationError } from '@/components/Conversation'; import { TOOL_PYTHON_INTERPRETER_ID } from '@/constants'; import { useAgent, useAvailableTools, useConversation, useListTools } from '@/hooks'; @@ -24,7 +24,7 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({ const { setConversation } = useConversationStore(); const { addCitation, saveOutputFiles } = useCitationsStore(); const { setParams, resetFileParams } = useParamsStore(); - const { availableTools } = useAvailableTools({ agent, managedTools: tools }); + const { availableTools } = useAvailableTools({ agent, allTools: tools }); const { data: conversation, @@ -44,7 +44,7 @@ const Chat: React.FC<{ agentId?: string; conversationId?: string }> = ({ .map((name) => (tools ?? [])?.find((t) => t.name === name)) .filter( (t) => t !== undefined && availableTools.some((at) => at.name === t?.name) - ) as ManagedTool[]); + ) as ToolDefinition[]); const fileIds = conversation?.files.map((file) => file.id); diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts index 81dfd72bcd..2043c8e93d 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts @@ -270,12 +270,6 @@ export const $Body_batch_upload_file_v1_conversations_batch_upload_file_post = { title: 'Body_batch_upload_file_v1_conversations_batch_upload_file_post', } as const; -export const $Category = { - type: 'string', - enum: ['Data loader', 'File loader', 'Function', 'Web search'], - title: 'Category', -} as const; - export const $ChatMessage = { properties: { role: { @@ -1919,114 +1913,6 @@ export const $Logout = { title: 'Logout', } as const; -export const $ManagedTool = { - properties: { - name: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Name', - default: '', - }, - display_name: { - type: 'string', - title: 'Display Name', - default: '', - }, - description: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Description', - default: '', - }, - parameter_definitions: { - anyOf: [ - { - type: 'object', - }, - { - type: 'null', - }, - ], - title: 'Parameter Definitions', - default: {}, - }, - kwargs: { - type: 'object', - title: 'Kwargs', - default: {}, - }, - is_visible: { - type: 'boolean', - title: 'Is Visible', - default: false, - }, - is_available: { - type: 'boolean', - title: 'Is Available', - default: false, - }, - error_message: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Error Message', - default: '', - }, - category: { - $ref: '#/components/schemas/Category', - default: 'Data loader', - }, - is_auth_required: { - type: 'boolean', - title: 'Is Auth Required', - default: false, - }, - auth_url: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Auth Url', - default: '', - }, - token: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Token', - default: '', - }, - }, - type: 'object', - title: 'ManagedTool', -} as const; - export const $Message = { properties: { text: { @@ -3228,6 +3114,120 @@ export const $ToolCallDelta = { title: 'ToolCallDelta', } as const; +export const $ToolCategory = { + type: 'string', + enum: ['Data loader', 'File loader', 'Function', 'Web search'], + title: 'ToolCategory', +} as const; + +export const $ToolDefinition = { + properties: { + name: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Name', + default: '', + }, + display_name: { + type: 'string', + title: 'Display Name', + default: '', + }, + description: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Description', + default: '', + }, + parameter_definitions: { + anyOf: [ + { + type: 'object', + }, + { + type: 'null', + }, + ], + title: 'Parameter Definitions', + default: {}, + }, + kwargs: { + type: 'object', + title: 'Kwargs', + default: {}, + }, + is_visible: { + type: 'boolean', + title: 'Is Visible', + default: false, + }, + is_available: { + type: 'boolean', + title: 'Is Available', + default: false, + }, + error_message: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Error Message', + default: '', + }, + category: { + $ref: '#/components/schemas/ToolCategory', + default: 'Data loader', + }, + is_auth_required: { + type: 'boolean', + title: 'Is Auth Required', + default: false, + }, + auth_url: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Auth Url', + default: '', + }, + token: { + anyOf: [ + { + type: 'string', + }, + { + type: 'null', + }, + ], + title: 'Token', + default: '', + }, + }, + type: 'object', + title: 'ToolDefinition', +} as const; + export const $ToolInputType = { type: 'string', enum: ['QUERY', 'CODE'], diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts index f281bd5624..2bc613de69 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts @@ -293,7 +293,7 @@ export class DefaultService { * If completed, the corresponding ToolAuth for the requesting user is removed from the DB. * * Args: - * tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the ToolName string enum class. + * tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the Tool enum. * request (Request): current Request object. * session (DBSessionDep): Database session. * ctx (Context): Context object. @@ -990,10 +990,10 @@ export class DefaultService { * agent_id (str): Agent ID. * ctx (Context): Context object. * Returns: - * list[ManagedTool]: List of available tools. + * list[ToolDefinition]: List of available tools. * @param data The data for the request. * @param data.agentId - * @returns ManagedTool Successful Response + * @returns ToolDefinition Successful Response * @throws ApiError */ public listToolsV1ToolsGet( @@ -1785,7 +1785,7 @@ export class DefaultService { * session (DBSessionDep): Database session. * * Returns: - * list[ManagedTool]: List of available organizations. + * list[Organization]: List of available organizations. * @returns Organization Successful Response * @throws ApiError */ @@ -1866,9 +1866,10 @@ export class DefaultService { * Args: * organization_id (str): Tool ID. * session (DBSessionDep): Database session. + * ctx: Context. * * Returns: - * ManagedTool: Organization with the given ID. + * Organization: Organization with the given ID. * @param data The data for the request. * @param data.organizationId * @returns Organization Successful Response diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts index c6cb588614..c8329f7488 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts @@ -56,13 +56,6 @@ export type Body_batch_upload_file_v1_conversations_batch_upload_file_post = { files: Array; }; -export enum Category { - DATA_LOADER = 'Data loader', - FILE_LOADER = 'File loader', - FUNCTION = 'Function', - WEB_SEARCH = 'Web search', -} - /** * A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's message. */ @@ -393,25 +386,6 @@ export type Login = { export type Logout = unknown; -export type ManagedTool = { - name?: string | null; - display_name?: string; - description?: string | null; - parameter_definitions?: { - [key: string]: unknown; - } | null; - kwargs?: { - [key: string]: unknown; - }; - is_visible?: boolean; - is_available?: boolean; - error_message?: string | null; - category?: Category; - is_auth_required?: boolean; - auth_url?: string | null; - token?: string | null; -}; - export type Message = { text: string; id: string; @@ -677,6 +651,32 @@ export type ToolCallDelta = { parameters: string | null; }; +export enum ToolCategory { + DATA_LOADER = 'Data loader', + FILE_LOADER = 'File loader', + FUNCTION = 'Function', + WEB_SEARCH = 'Web search', +} + +export type ToolDefinition = { + name?: string | null; + display_name?: string; + description?: string | null; + parameter_definitions?: { + [key: string]: unknown; + } | null; + kwargs?: { + [key: string]: unknown; + }; + is_visible?: boolean; + is_available?: boolean; + error_message?: string | null; + category?: ToolCategory; + is_auth_required?: boolean; + auth_url?: string | null; + token?: string | null; +}; + /** * Type of input passed to the tool */ @@ -961,7 +961,7 @@ export type ListToolsV1ToolsGetData = { agentId?: string | null; }; -export type ListToolsV1ToolsGetResponse = Array; +export type ListToolsV1ToolsGetResponse = Array; export type CreateDeploymentV1DeploymentsPostData = { requestBody: DeploymentCreate; @@ -1611,7 +1611,7 @@ export type $OpenApiTs = { /** * Successful Response */ - 200: Array; + 200: Array; /** * Validation Error */ diff --git a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx index 595dd0812b..b712c5d890 100644 --- a/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx +++ b/src/interfaces/assistants_web/src/components/AgentSettingsForm/ToolsStep.tsx @@ -1,12 +1,12 @@ import Link from 'next/link'; -import { ManagedTool } from '@/cohere-client'; +import { ToolDefinition } from '@/cohere-client'; import { StatusConnection } from '@/components/AgentSettingsForm/StatusConnection'; import { Button, Icon, IconName, Switch, Text } from '@/components/UI'; import { AGENT_SETTINGS_TOOLS, TOOL_FALLBACK_ICON, TOOL_ID_TO_DISPLAY_INFO } from '@/constants'; type Props = { - tools?: ManagedTool[]; + tools?: ToolDefinition[]; activeTools?: string[]; setActiveTools: (tools: string[]) => void; handleAuthButtonClick: (toolName: string) => void; diff --git a/src/interfaces/assistants_web/src/components/Composer/Composer.tsx b/src/interfaces/assistants_web/src/components/Composer/Composer.tsx index 418109480d..90580c0411 100644 --- a/src/interfaces/assistants_web/src/components/Composer/Composer.tsx +++ b/src/interfaces/assistants_web/src/components/Composer/Composer.tsx @@ -3,7 +3,7 @@ import { useResizeObserver } from '@react-hookz/web'; import React, { useEffect, useRef, useState } from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { ComposerError, ComposerFiles, ComposerToolbar } from '@/components/Composer'; import { DragDropFileInput, Icon, STYLE_LEVEL_TO_CLASSES } from '@/components/UI'; import { CHAT_COMPOSER_TEXTAREA_ID } from '@/constants'; @@ -21,7 +21,7 @@ type Props = { onChange: (message: string) => void; onUploadFile: (files: File[]) => void; agent?: AgentPublic; - tools?: ManagedTool[]; + tools?: ToolDefinition[]; chatWindowRef?: React.RefObject; lastUserMessage?: ChatMessage; }; @@ -42,7 +42,7 @@ export const Composer: React.FC = ({ const breakpoint = useBreakpoint(); const isSmallBreakpoint = breakpoint === 'sm'; const textareaRef = useRef(null); - const { unauthedTools } = useAvailableTools({ agent, managedTools: tools }); + const { unauthedTools } = useAvailableTools({ agent, allTools: tools }); const isToolAuthRequired = unauthedTools.length > 0; const [chatWindowHeight, setChatWindowHeight] = useState(0); diff --git a/src/interfaces/assistants_web/src/components/Composer/ComposerToolbar.tsx b/src/interfaces/assistants_web/src/components/Composer/ComposerToolbar.tsx index 75df594412..4b78be2838 100644 --- a/src/interfaces/assistants_web/src/components/Composer/ComposerToolbar.tsx +++ b/src/interfaces/assistants_web/src/components/Composer/ComposerToolbar.tsx @@ -2,13 +2,13 @@ import React from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { DataSourceMenu, FilesMenu } from '@/components/Composer'; import { cn } from '@/utils'; type Props = { agent?: AgentPublic; - tools?: ManagedTool[]; + tools?: ToolDefinition[]; onUploadFile: (files: File[]) => void; }; diff --git a/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx b/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx index 1bbdd5099d..24761cb6d8 100644 --- a/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx +++ b/src/interfaces/assistants_web/src/components/Composer/DataSourceMenu.tsx @@ -3,7 +3,7 @@ import { Popover, PopoverButton, PopoverPanel } from '@headlessui/react'; import React from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { Icon, Switch, Text } from '@/components/UI'; import { useAvailableTools, useBrandedColors } from '@/hooks'; import { useParamsStore } from '@/stores'; @@ -11,7 +11,7 @@ import { checkIsBaseAgent, cn, getToolIcon } from '@/utils'; export type Props = { agent?: AgentPublic; - tools?: ManagedTool[]; + tools?: ToolDefinition[]; }; /** @@ -23,7 +23,7 @@ export const DataSourceMenu: React.FC = ({ agent, tools }) => { } = useParamsStore(); const { availableTools, handleToggle } = useAvailableTools({ agent, - managedTools: tools, + allTools: tools, }); const { text, contrastText, border, bg } = useBrandedColors(agent?.id); diff --git a/src/interfaces/assistants_web/src/components/Conversation/Conversation.tsx b/src/interfaces/assistants_web/src/components/Conversation/Conversation.tsx index 4e6c1158e0..336033fb5c 100644 --- a/src/interfaces/assistants_web/src/components/Conversation/Conversation.tsx +++ b/src/interfaces/assistants_web/src/components/Conversation/Conversation.tsx @@ -2,7 +2,7 @@ import React, { useRef } from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { Composer } from '@/components/Composer'; import { Header } from '@/components/Conversation'; import { MessagingContainer, WelcomeGuideTooltip } from '@/components/MessagingContainer'; @@ -19,7 +19,7 @@ import { ChatMessage } from '@/types/message'; type Props = { startOptionsEnabled?: boolean; agent?: AgentPublic; - tools?: ManagedTool[]; + tools?: ToolDefinition[]; history?: ChatMessage[]; }; diff --git a/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx b/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx index 84b7c56482..9faad2f7bf 100644 --- a/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx +++ b/src/interfaces/assistants_web/src/components/MessagingContainer/AssistantTools.tsx @@ -2,7 +2,7 @@ import React from 'react'; -import { AgentPublic, ManagedTool } from '@/cohere-client'; +import { AgentPublic, ToolDefinition } from '@/cohere-client'; import { WelcomeGuideTooltip } from '@/components/MessagingContainer'; import { Button, Icon, Text, ToggleCard } from '@/components/UI'; import { useAvailableTools } from '@/hooks'; @@ -13,7 +13,7 @@ import { checkIsBaseAgent, cn, getToolIcon } from '@/utils'; * @description Tools for the assistant to use in the conversation. */ export const AssistantTools: React.FC<{ - tools: ManagedTool[]; + tools: ToolDefinition[]; agent?: AgentPublic; className?: string; }> = ({ tools, agent, className = '' }) => { @@ -23,7 +23,7 @@ export const AssistantTools: React.FC<{ const enabledTools = paramTools ?? []; const { availableTools, unauthedTools, handleToggle } = useAvailableTools({ agent, - managedTools: tools, + allTools: tools, }); if (availableTools.length === 0) return null; diff --git a/src/interfaces/assistants_web/src/hooks/use-tools.ts b/src/interfaces/assistants_web/src/hooks/use-tools.ts index 584e437f1d..2c4959f9eb 100644 --- a/src/interfaces/assistants_web/src/hooks/use-tools.ts +++ b/src/interfaces/assistants_web/src/hooks/use-tools.ts @@ -3,7 +3,7 @@ import { useMemo } from 'react'; import useDrivePicker from 'react-google-drive-picker'; import type { PickerCallback } from 'react-google-drive-picker/dist/typeDefs'; -import { AgentPublic, ApiError, ManagedTool, useCohereClient } from '@/cohere-client'; +import { AgentPublic, ApiError, ToolDefinition, useCohereClient } from '@/cohere-client'; import { BASE_AGENT_EXCLUDED_TOOLS, DEFAULT_AGENT_TOOLS, TOOL_GOOGLE_DRIVE_ID } from '@/constants'; import { env } from '@/env.mjs'; import { useNotify } from '@/hooks'; @@ -13,7 +13,7 @@ import { checkIsBaseAgent } from '@/utils'; export const useListTools = (enabled: boolean = true) => { const client = useCohereClient(); - return useQuery({ + return useQuery({ queryKey: ['tools'], queryFn: async () => { const tools = await client.listTools({}); @@ -84,10 +84,10 @@ export const useOpenGoogleDrivePicker = (callbackFunction: (data: PickerCallback export const useAvailableTools = ({ agent, - managedTools, + allTools, }: { agent?: AgentPublic; - managedTools?: ManagedTool[]; + allTools?: ToolDefinition[]; }) => { const requiredTools = agent?.tools; @@ -106,14 +106,14 @@ export const useAvailableTools = ({ ) ?? []; const availableTools = useMemo(() => { - return (managedTools ?? []).filter( + return (allTools ?? []).filter( (t) => t.is_visible && t.is_available && (!requiredTools || requiredTools.some((rt) => rt === t.name)) && !(isBaseAgent && BASE_AGENT_EXCLUDED_TOOLS.some((rt) => rt === t.name)) ); - }, [managedTools, requiredTools]); + }, [allTools, requiredTools]); const handleToggle = (name: string, checked: boolean) => { const newParams: Partial = {