Skip to content

Commit

Permalink
backend: web scraping in tools + fix metrics data use in tools (#393)
Browse files Browse the repository at this point in the history
* all changes

* lint

* code review changes

* error msg

* update tool descriptions

* fix
  • Loading branch information
scott-cohere authored Jul 11, 2024
1 parent ffc711d commit e0cf5ff
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 40 deletions.
1 change: 1 addition & 0 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ async def call_tools(self, chat_history, deployment_model, **kwargs: Any):
session=kwargs.get("session"),
model_deployment=deployment_model,
user_id=kwargs.get("user_id"),
trace_id=kwargs.get("trace_id"),
agent_id=kwargs.get("agent_id"),
)

Expand Down
5 changes: 4 additions & 1 deletion src/backend/config/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ class RouterName(StrEnum):
],
},
RouterName.TOOL: {
"default": [],
"default": [
Depends(get_session),
],
"auth": [
Depends(get_session),
Depends(validate_authorization),
],
},
Expand Down
24 changes: 24 additions & 0 deletions src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ReadFileTool,
SearchFileTool,
TavilyInternetSearch,
WebScrapeTool,
)

"""
Expand All @@ -35,6 +36,7 @@ class ToolName(StrEnum):
Calculator = Calculator.NAME
Tavily_Internet_Search = TavilyInternetSearch.NAME
Google_Drive = GoogleDrive.NAME
Web_Scrape = WebScrapeTool.NAME


ALL_TOOLS = {
Expand Down Expand Up @@ -157,6 +159,28 @@ class ToolName(StrEnum):
category=Category.DataLoader,
description="Returns a list of relevant document snippets for the user's google drive.",
),
ToolName.Web_Scrape: ManagedTool(
name=ToolName.Web_Scrape,
display_name="Web Scrape",
implementation=WebScrapeTool,
parameter_definitions={
"url": {
"description": "The url to scrape.",
"type": "str",
"required": True,
},
"query": {
"description": "The query to use to select the most relevant passages to return. Using an empty string will return the passages in the order they appear on the webpage",
"type": "str",
"required": False,
},
},
is_visible=True,
is_available=WebScrapeTool.is_available(),
error_message="WebScrapeTool not available.",
category=Category.DataLoader,
description="Scrape and returns the textual contents of a webpage as a list of passages for a given url.",
),
}


Expand Down
10 changes: 3 additions & 7 deletions src/backend/model_deployments/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,22 @@ def is_available(cls) -> bool:
return all([os.environ.get(var) is not None for var in AZURE_ENV_VARS])

@collect_metrics_chat
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
**kwargs,
)
yield to_dict(response)

@collect_metrics_chat_stream
async def invoke_chat_stream(
self, chat_request: CohereChatRequest, **kwargs: Any
self, chat_request: CohereChatRequest
) -> AsyncGenerator[Any, Any]:
stream = self.client.chat_stream(
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
**kwargs,
)

for event in stream:
yield to_dict(event)

async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], **kwargs: Any
) -> Any:
async def invoke_rerank(self, query: str, documents: List[Dict[str, Any]]) -> Any:
return None
10 changes: 3 additions & 7 deletions src/backend/model_deployments/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,20 @@ def is_available(cls) -> bool:
return all([os.environ.get(var) is not None for var in BEDROCK_ENV_VARS])

@collect_metrics_chat
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
# bedrock accepts a subset of the chat request fields
bedrock_chat_req = chat_request.model_dump(
exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True
)

response = self.client.chat(
**bedrock_chat_req,
**kwargs,
)
yield to_dict(response)

@collect_metrics_chat_stream
async def invoke_chat_stream(
self, chat_request: CohereChatRequest, **kwargs: Any
self, chat_request: CohereChatRequest
) -> AsyncGenerator[Any, Any]:
# bedrock accepts a subset of the chat request fields
bedrock_chat_req = chat_request.model_dump(
Expand All @@ -87,12 +86,9 @@ async def invoke_chat_stream(

stream = self.client.chat_stream(
**bedrock_chat_req,
**kwargs,
)
for event in stream:
yield to_dict(event)

async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], **kwargs: Any
) -> Any:
async def invoke_rerank(self, query: str, documents: List[Dict[str, Any]]) -> Any:
return None
4 changes: 1 addition & 3 deletions src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def is_available(cls) -> bool:
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
**kwargs,
)
yield to_dict(response)

Expand All @@ -83,7 +82,6 @@ async def invoke_chat_stream(
) -> AsyncGenerator[Any, Any]:
stream = self.client.chat_stream(
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
**kwargs,
)

for event in stream:
Expand All @@ -94,5 +92,5 @@ async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], **kwargs: Any
) -> Any:
return self.client.rerank(
query=query, documents=documents, model=DEFAULT_RERANK_MODEL, **kwargs
query=query, documents=documents, model=DEFAULT_RERANK_MODEL
)
6 changes: 2 additions & 4 deletions src/backend/model_deployments/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def is_available(cls) -> bool:

@collect_metrics_chat_stream
async def invoke_chat_stream(
self, chat_request: CohereChatRequest, **kwargs: Any
self, chat_request: CohereChatRequest
) -> AsyncGenerator[Any, Any]:
# Create the payload for the request
json_params = {
Expand All @@ -100,9 +100,7 @@ async def invoke_chat_stream(
stream_event["index"] = index
yield stream_event

async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], **kwargs: Any
) -> Any:
async def invoke_rerank(self, query: str, documents: List[Dict[str, Any]]) -> Any:
return None

# This class iterates through each line of Sagemaker's response
Expand Down
12 changes: 4 additions & 8 deletions src/backend/model_deployments/single_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,29 @@ def is_available(cls) -> bool:
return all([os.environ.get(var) is not None for var in SC_ENV_VARS])

@collect_metrics_chat
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
response = self.client.chat(
**chat_request.model_dump(
exclude={"stream", "file_ids", "model", "agent_id"}
),
**kwargs,
)
yield to_dict(response)

@collect_metrics_chat_stream
async def invoke_chat_stream(
self, chat_request: CohereChatRequest, **kwargs: Any
self, chat_request: CohereChatRequest
) -> AsyncGenerator[Any, Any]:
stream = self.client.chat_stream(
**chat_request.model_dump(
exclude={"stream", "file_ids", "model", "agent_id"}
),
**kwargs,
)

for event in stream:
yield to_dict(event)

@collect_metrics_rerank
async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], **kwargs: Any
) -> Any:
async def invoke_rerank(self, query: str, documents: List[Dict[str, Any]]) -> Any:
return self.client.rerank(
query=query, documents=documents, model=DEFAULT_RERANK_MODEL, **kwargs
query=query, documents=documents, model=DEFAULT_RERANK_MODEL
)
6 changes: 3 additions & 3 deletions src/backend/services/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,9 @@ def initialize_sdk_metrics_data(
MetricsData(
id=str(uuid.uuid4()),
message_type=message_type,
trace_id=kwargs.pop("trace_id", None),
user_id=kwargs.pop("user_id", None),
assistant_id=kwargs.pop("agent_id", None),
trace_id=kwargs.get("trace_id", None),
user_id=kwargs.get("user_id", None),
assistant_id=kwargs.get("agent_id", None),
model=chat_request.model if chat_request else None,
),
kwargs,
Expand Down
2 changes: 2 additions & 0 deletions src/backend/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from backend.tools.lang_chain import LangChainVectorDBRetriever, LangChainWikiRetriever
from backend.tools.python_interpreter import PythonInterpreter
from backend.tools.tavily import TavilyInternetSearch
from backend.tools.web_scrape import WebScrapeTool

__all__ = [
"Calculator",
Expand All @@ -15,4 +16,5 @@
"SearchFileTool",
"GoogleDrive",
"GoogleDriveAuth",
"WebScrapeTool",
]
77 changes: 70 additions & 7 deletions src/backend/tools/tavily.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import copy
import os
from typing import Any, Dict, List

from langchain_community.tools.tavily_search import TavilySearchResults
from tavily import TavilyClient

from backend.model_deployments.base import BaseDeployment
from backend.tools.base import BaseTool


Expand All @@ -13,26 +15,87 @@ class TavilyInternetSearch(BaseTool):

def __init__(self):
self.client = TavilyClient(api_key=self.TAVILY_API_KEY)
self.num_results = 6

@classmethod
def is_available(cls) -> bool:
return cls.TAVILY_API_KEY is not None

async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
    """Run a Tavily web search and return reranked passages for the query.

    Searches with ``include_raw_content=True`` so each hit's full page text
    is available, expands every hit into additional passage-level snippets,
    reranks all passages against the query via ``self.rerank_page_snippets``,
    and returns them as ``{"url": ..., "text": ...}`` dicts.

    Args:
        parameters: Tool parameters; uses ``parameters["query"]`` (defaults
            to the empty string).
        **kwargs: Forwarded to reranking; ``model_deployment`` is used as the
            rerank model.

    Returns:
        A list of ``{"url": str, "text": str}`` passage dicts, most relevant
        first.
    """
    query = parameters.get("query", "")
    response = self.client.search(
        query=query, search_depth="advanced", include_raw_content=True
    )

    if "results" not in response:
        return []

    expanded = []
    # NOTE: use a distinct loop name — the original shadowed the response
    # dict with its own loop variable.
    for search_result in response["results"]:
        # Append original search result
        expanded.append(search_result)

        # Get other snippets. raw_content can be None/absent when Tavily
        # fails to fetch the page body — treat that as "no extra snippets".
        raw_content = search_result.get("raw_content") or ""
        for snippet in raw_content.split("\n"):
            if search_result["content"] != snippet:
                if len(snippet.split()) <= 10:
                    continue  # Skip snippets with less than 10 words

                expanded.append(
                    {
                        "url": search_result["url"],
                        "title": search_result["title"],
                        "content": snippet.strip(),
                    }
                )

    reranked_results = await self.rerank_page_snippets(
        query, expanded, model=kwargs.get("model_deployment"), **kwargs
    )

    return [
        {"url": result["url"], "text": result["content"]}
        for result in reranked_results
    ]

async def rerank_page_snippets(
    self,
    query: str,
    snippets: List[Dict[str, Any]],
    model: "BaseDeployment",
    **kwargs: Any,
) -> List[Dict[str, Any]]:
    """Rerank snippet dicts by relevance to the query and dedupe by URL.

    Calls ``model.invoke_rerank`` in batches of 500, orders snippets by the
    returned relevance scores (highest first), keeps only the first snippet
    seen for each URL, and returns at most ``self.num_results`` entries.

    Args:
        query: The search query to rerank against.
        snippets: Dicts with at least ``title``, ``content`` and ``url`` keys.
        model: Deployment providing ``invoke_rerank``; deployments without
            rerank support may return ``None``, which is tolerated.
        **kwargs: Forwarded to ``invoke_rerank``.

    Returns:
        Up to ``self.num_results`` snippet dicts, most relevant first.
    """
    if len(snippets) == 0:
        return []

    rerank_batch_size = 500
    # Default to -inf (not None): snippets the rerank endpoint does not
    # score must sort as least relevant instead of crashing the
    # float-vs-None comparison inside sorted().
    relevance_scores = [float("-inf")] * len(snippets)
    for batch_start in range(0, len(snippets), rerank_batch_size):
        snippet_batch = snippets[batch_start : batch_start + rerank_batch_size]
        batch_output = await model.invoke_rerank(
            query=query,
            documents=[
                f"{snippet['title']} {snippet['content']}"
                for snippet in snippet_batch
            ],
            **kwargs,
        )
        # Some deployments (azure, bedrock, sagemaker) implement
        # invoke_rerank as a no-op returning None — skip those batches.
        if not batch_output:
            continue
        for b in batch_output.get("results", []):
            index = b.get("index", None)
            relevance_score = b.get("relevance_score", None)
            if index is not None and relevance_score is not None:
                relevance_scores[batch_start + index] = relevance_score

    reranked, seen_urls = [], []
    for _, result in sorted(
        zip(relevance_scores, snippets), key=lambda x: x[0], reverse=True
    ):
        if result["url"] not in seen_urls:
            seen_urls.append(result["url"])
            reranked.append(result)

    return reranked[: self.num_results]

def to_langchain_tool(self) -> TavilySearchResults:
internet_search = TavilySearchResults()
internet_search.name = "internet_search"
Expand Down
41 changes: 41 additions & 0 deletions src/backend/tools/web_scrape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any, Dict, List

from bs4 import BeautifulSoup
from requests import get

from backend.tools.base import BaseTool


class WebScrapeTool(BaseTool):
    """Tool that fetches a web page and returns its visible text.

    ``call()`` downloads the page at ``parameters["url"]`` and returns a
    single passage dict ``{"text": <page text>, "url": <url>}``. On an HTTP
    error status or a network-level failure it returns an explanatory error
    passage instead of raising, so the chat flow can surface the problem to
    the model.
    """

    NAME = "web_scrape"

    @classmethod
    def is_available(cls) -> bool:
        # No credentials or configuration required; always usable.
        return True

    async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
        from requests.exceptions import RequestException

        url = parameters.get("url")

        try:
            # A timeout keeps a stalled server from blocking indefinitely:
            # requests.get has no default timeout.
            response = get(url, timeout=30)
        except RequestException as e:
            # DNS failures, refused connections, timeouts, invalid URLs —
            # report them the same way as HTTP error statuses below.
            return [
                {
                    "text": f"Cannot open and scrape URL {url}, Error: {e}",
                    "url": url,
                }
            ]

        if not response.ok:
            error_message = f"HTTP {response.status_code} {response.reason}"
            return [
                {
                    "text": f"Cannot open and scrape URL {url}, Error: {error_message}",
                    "url": url,
                }
            ]

        soup = BeautifulSoup(response.text, "html.parser")
        text = soup.get_text(separator="\n")

        return [
            {
                "text": text,
                "url": url,
            }
        ]

0 comments on commit e0cf5ff

Please sign in to comment.