From e0cf5ff9383a1f189ee34cfe9f80cd8c424566d9 Mon Sep 17 00:00:00 2001 From: Scott <146760070+scott-cohere@users.noreply.github.com> Date: Thu, 11 Jul 2024 12:52:43 -0400 Subject: [PATCH] backend: web scraping in tools + fix metrics data use in tools (#393) * all changes * lint * code review changes * error msg * update tool descriptions * fix --- src/backend/chat/custom/custom.py | 1 + src/backend/config/routers.py | 5 +- src/backend/config/tools.py | 24 ++++++ src/backend/model_deployments/azure.py | 10 +-- src/backend/model_deployments/bedrock.py | 10 +-- .../model_deployments/cohere_platform.py | 4 +- src/backend/model_deployments/sagemaker.py | 6 +- .../model_deployments/single_container.py | 12 +-- src/backend/services/metrics.py | 6 +- src/backend/tools/__init__.py | 2 + src/backend/tools/tavily.py | 77 +++++++++++++++++-- src/backend/tools/web_scrape.py | 41 ++++++++++ 12 files changed, 158 insertions(+), 40 deletions(-) create mode 100644 src/backend/tools/web_scrape.py diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index c9ff923900..b33a383979 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -255,6 +255,7 @@ async def call_tools(self, chat_history, deployment_model, **kwargs: Any): session=kwargs.get("session"), model_deployment=deployment_model, user_id=kwargs.get("user_id"), + trace_id=kwargs.get("trace_id"), agent_id=kwargs.get("agent_id"), ) diff --git a/src/backend/config/routers.py b/src/backend/config/routers.py index 74970141de..12ecb56b13 100644 --- a/src/backend/config/routers.py +++ b/src/backend/config/routers.py @@ -75,8 +75,11 @@ class RouterName(StrEnum): ], }, RouterName.TOOL: { - "default": [], + "default": [ + Depends(get_session), + ], "auth": [ + Depends(get_session), Depends(validate_authorization), ], }, diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 2c8b1562e2..dd6c16a65d 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -13,6 +13,7 @@ ReadFileTool, SearchFileTool, TavilyInternetSearch, + WebScrapeTool, ) """ @@ -35,6 +36,7 @@ class ToolName(StrEnum): Calculator = Calculator.NAME Tavily_Internet_Search = TavilyInternetSearch.NAME Google_Drive = GoogleDrive.NAME + Web_Scrape = WebScrapeTool.NAME ALL_TOOLS = { @@ -157,6 +159,28 @@ class ToolName(StrEnum): category=Category.DataLoader, description="Returns a list of relevant document snippets for the user's google drive.", ), + ToolName.Web_Scrape: ManagedTool( + name=ToolName.Web_Scrape, + display_name="Web Scrape", + implementation=WebScrapeTool, + parameter_definitions={ + "url": { + "description": "The url to scrape.", + "type": "str", + "required": True, + }, + "query": { + "description": "The query to use to select the most relevant passages to return. Using an empty string will return the passages in the order they appear on the webpage", + "type": "str", + "required": False, + }, + }, + is_visible=True, + is_available=WebScrapeTool.is_available(), + error_message="WebScrapeTool not available.", + category=Category.DataLoader, + description="Scrape and returns the textual contents of a webpage as a list of passages for a given url.", + ), } diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index b494933f33..d8cf4eeee1 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -63,26 +63,22 @@ def is_available(cls) -> bool: return all([os.environ.get(var) is not None for var in AZURE_ENV_VARS]) @collect_metrics_chat - async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: response = self.client.chat( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), - **kwargs, ) yield to_dict(response) @collect_metrics_chat_stream async def invoke_chat_stream( - self, chat_request: CohereChatRequest, **kwargs: Any + self, chat_request: CohereChatRequest ) -> AsyncGenerator[Any, Any]: stream = self.client.chat_stream( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), - **kwargs, ) for event in stream: yield to_dict(event) - async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], **kwargs: Any - ) -> Any: + async def invoke_rerank(self, query: str, documents: List[Dict[str, Any]]) -> Any: return None diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index 678f69ab89..755678f395 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -64,7 +64,7 @@ def is_available(cls) -> bool: return all([os.environ.get(var) is not None for var in BEDROCK_ENV_VARS]) @collect_metrics_chat - async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: # bedrock accepts a subset of the chat request fields bedrock_chat_req = chat_request.model_dump( exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True @@ -72,13 +72,12 @@ async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> A response = self.client.chat( **bedrock_chat_req, - **kwargs, ) yield to_dict(response) @collect_metrics_chat_stream async def invoke_chat_stream( - self, chat_request: CohereChatRequest, **kwargs: Any + self, chat_request: CohereChatRequest ) -> AsyncGenerator[Any, Any]: # bedrock accepts a subset of the chat request fields bedrock_chat_req = chat_request.model_dump( @@ -87,12 +86,9 @@ async def invoke_chat_stream( stream = self.client.chat_stream( **bedrock_chat_req, - **kwargs, ) for event in stream: yield to_dict(event) - async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], **kwargs: Any - ) -> Any: + async def invoke_rerank(self, query: str, documents: List[Dict[str, Any]]) -> Any: return None diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index 71651808b4..15a0c5e2f9 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -73,7 +73,6 @@ def is_available(cls) -> bool: async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: response = self.client.chat( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), - **kwargs, ) yield to_dict(response) @@ -83,7 +82,6 @@ async def invoke_chat_stream( ) -> AsyncGenerator[Any, Any]: stream = self.client.chat_stream( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), - **kwargs, ) for event in stream: @@ -94,5 +92,5 @@ async def invoke_rerank( self, query: str, documents: List[Dict[str, Any]], **kwargs: Any ) -> Any: return self.client.rerank( - query=query, documents=documents, model=DEFAULT_RERANK_MODEL, **kwargs + query=query, documents=documents, model=DEFAULT_RERANK_MODEL ) diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index e2e824caac..f5358f731d 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -80,7 +80,7 @@ def is_available(cls) -> bool: @collect_metrics_chat_stream async def invoke_chat_stream( - self, chat_request: CohereChatRequest, **kwargs: Any + self, chat_request: CohereChatRequest ) -> AsyncGenerator[Any, Any]: # Create the payload for the request json_params = { @@ -100,9 +100,7 @@ async def invoke_chat_stream( stream_event["index"] = index yield stream_event - async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], **kwargs: Any - ) -> Any: + async def invoke_rerank(self, query: str, documents: List[Dict[str, Any]]) -> Any: return None # This class iterates through each line of Sagemaker's response diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index 4beefb4504..8a0c6c5cd9 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -48,33 +48,29 @@ def is_available(cls) -> bool: return all([os.environ.get(var) is not None for var in SC_ENV_VARS]) @collect_metrics_chat - async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: response = self.client.chat( **chat_request.model_dump( exclude={"stream", "file_ids", "model", "agent_id"} ), - **kwargs, ) yield to_dict(response) @collect_metrics_chat_stream async def invoke_chat_stream( - self, chat_request: CohereChatRequest, **kwargs: Any + self, chat_request: CohereChatRequest ) -> AsyncGenerator[Any, Any]: stream = self.client.chat_stream( **chat_request.model_dump( exclude={"stream", "file_ids", "model", "agent_id"} ), - **kwargs, ) for event in stream: yield to_dict(event) @collect_metrics_rerank - async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], **kwargs: Any - ) -> Any: + async def invoke_rerank(self, query: str, documents: List[Dict[str, Any]]) -> Any: return self.client.rerank( - query=query, documents=documents, model=DEFAULT_RERANK_MODEL, **kwargs + query=query, documents=documents, model=DEFAULT_RERANK_MODEL ) diff --git a/src/backend/services/metrics.py b/src/backend/services/metrics.py index 4e3e67d219..b07570eaeb 100644 --- a/src/backend/services/metrics.py +++ b/src/backend/services/metrics.py @@ -372,9 +372,9 @@ def initialize_sdk_metrics_data( MetricsData( id=str(uuid.uuid4()), message_type=message_type, - trace_id=kwargs.pop("trace_id", None), - user_id=kwargs.pop("user_id", None), - assistant_id=kwargs.pop("agent_id", None), + trace_id=kwargs.get("trace_id", None), + user_id=kwargs.get("user_id", None), + assistant_id=kwargs.get("agent_id", None), model=chat_request.model if chat_request else None, ), kwargs, diff --git a/src/backend/tools/__init__.py b/src/backend/tools/__init__.py index c16ec0a6a2..f4c70701d6 100644 --- a/src/backend/tools/__init__.py +++ b/src/backend/tools/__init__.py @@ -4,6 +4,7 @@ from backend.tools.lang_chain import LangChainVectorDBRetriever, LangChainWikiRetriever from backend.tools.python_interpreter import PythonInterpreter from backend.tools.tavily import TavilyInternetSearch +from backend.tools.web_scrape import WebScrapeTool __all__ = [ "Calculator", @@ -15,4 +16,5 @@ "SearchFileTool", "GoogleDrive", "GoogleDriveAuth", + "WebScrapeTool", ] diff --git a/src/backend/tools/tavily.py b/src/backend/tools/tavily.py index 3e93c82bbf..9fe4e792af 100644 --- a/src/backend/tools/tavily.py +++ b/src/backend/tools/tavily.py @@ -1,9 +1,11 @@ +import copy import os from typing import Any, Dict, List from langchain_community.tools.tavily_search import TavilySearchResults from tavily import TavilyClient +from backend.model_deployments.base import BaseDeployment from backend.tools.base import BaseTool @@ -13,6 +15,7 @@ class TavilyInternetSearch(BaseTool): def __init__(self): self.client = TavilyClient(api_key=self.TAVILY_API_KEY) + self.num_results = 6 @classmethod def is_available(cls) -> bool: @@ -20,19 +23,79 @@ def is_available(cls) -> bool: async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: query = parameters.get("query", "") - content = self.client.search(query=query, search_depth="advanced") + result = self.client.search( + query=query, search_depth="advanced", include_raw_content=True + ) - if "results" not in content: + if "results" not in result: return [] + expanded = [] + for result in result["results"]: + # Append original search result + expanded.append(result) + + # Get other snippets + snippets = result["raw_content"].split("\n") + for snippet in snippets: + if result["content"] != snippet: + if len(snippet.split()) <= 10: + continue # Skip snippets with less than 10 words + + new_result = { + "url": result["url"], + "title": result["title"], + "content": snippet.strip(), + } + expanded.append(new_result) + + reranked_results = await self.rerank_page_snippets( + query, expanded, model=kwargs.get("model_deployment"), **kwargs + ) + return [ - { - "url": result["url"], - "text": result["content"], - } - for result in content["results"] + {"url": result["url"], "text": result["content"]} + for result in reranked_results ] + async def rerank_page_snippets( + self, + query: str, + snippets: List[Dict[str, Any]], + model: BaseDeployment, + **kwargs: Any, + ) -> List[Dict[str, Any]]: + if len(snippets) == 0: + return [] + + rerank_batch_size = 500 + relevance_scores = [None for _ in range(len(snippets))] + for batch_start in range(0, len(snippets), rerank_batch_size): + snippet_batch = snippets[batch_start : batch_start + rerank_batch_size] + batch_output = await model.invoke_rerank( + query=query, + documents=[ + f"{snippet['title']} {snippet['content']}" + for snippet in snippet_batch + ], + **kwargs, + ) + for b in batch_output.get("results", []): + index = b.get("index", None) + relevance_score = b.get("relevance_score", None) + if index is not None: + relevance_scores[batch_start + index] = relevance_score + + reranked, seen_urls = [], [] + for _, result in sorted( + zip(relevance_scores, snippets), key=lambda x: x[0], reverse=True + ): + if result["url"] not in seen_urls: + seen_urls.append(result["url"]) + reranked.append(result) + + return reranked[: self.num_results] + def to_langchain_tool(self) -> TavilySearchResults: internet_search = TavilySearchResults() internet_search.name = "internet_search" diff --git a/src/backend/tools/web_scrape.py b/src/backend/tools/web_scrape.py new file mode 100644 index 0000000000..22c23072ac --- /dev/null +++ b/src/backend/tools/web_scrape.py @@ -0,0 +1,41 @@ +from typing import Any, Dict, List + +from bs4 import BeautifulSoup +from requests import get + +from backend.tools.base import BaseTool + + +class WebScrapeTool(BaseTool): + NAME = "web_scrape" + + @classmethod + def is_available(cls) -> bool: + return True + + async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]: + url = parameters.get("url") + + response = get(url) + if not response.ok: + error_message = f"HTTP {response.status_code} {response.reason}" + return [ + ( + { + "text": f"Cannot open and scrape URL {url}, Error: {error_message}", + "url": url, + } + ) + ] + + soup = BeautifulSoup(response.text, "html.parser") + text = soup.get_text(separator="\n") + + return [ + ( + { + "text": text, + "url": url, + } + ) + ]