class ToolStore(Protocol):
    """Lookup interface a tool runtime uses to resolve a registered tool.

    In the same change, ``ToolRuntime`` declares a ``tool_store: ToolStore``
    attribute and the routing table assigns itself to ``p.tool_store`` for
    ``Api.tool_runtime`` providers.

    ``get_tool`` is declared as a coroutine because the runtime awaits it
    (``await self.tool_store.get_tool(tool_id)``); a plain ``def`` here would
    make the declared contract disagree with its only visible caller.
    """

    async def get_tool(self, identifier: str) -> "Tool":
        """Return the registered tool whose identifier is ``identifier``."""
        ...
from typing import Any, Dict

from pydantic import BaseModel

from .config import MetaReferenceToolRuntimeConfig
from .meta_reference import MetaReferenceToolRuntimeImpl


class MetaReferenceProviderDataValidator(BaseModel):
    """Shape of the per-request provider data accepted by this provider."""

    # Key forwarded to the built-in tools; presumably the search provider's
    # API key -- confirm against the runtime implementation.
    api_key: str


async def get_provider_impl(
    config: MetaReferenceToolRuntimeConfig,
    _deps: Dict[str, Any],
) -> MetaReferenceToolRuntimeImpl:
    """Create the meta-reference tool runtime and initialize it.

    Args:
        config: Provider configuration handed to the runtime constructor.
        _deps: Resolved dependencies; not used by this provider.

    Returns:
        The runtime instance, after ``initialize()`` has been awaited.
    """
    runtime = MetaReferenceToolRuntimeImpl(config)
    await runtime.initialize()
    return runtime
# ---------------------------------------------------------------------------
# llama_stack/providers/inline/tool_runtime/meta_reference/builtins.py
# ---------------------------------------------------------------------------

import asyncio
import functools
import json
import logging
from typing import Any, Dict, Iterable

import requests

logger = logging.getLogger(__name__)

# Upper bound (seconds) on any outbound search request, so an unresponsive
# provider cannot hang a tool invocation forever.
_REQUEST_TIMEOUT_SECONDS = 30

# Brave result types where the "mixed" entry points at ONE result by index.
_BRAVE_SINGLE_RESULT_KEYS = {
    "web": ("type", "title", "url", "description", "date", "extra_snippets"),
    "infobox": ("type", "title", "url", "description", "long_desc"),
}

# Brave result types where every result of the section is returned.
_BRAVE_LIST_RESULT_KEYS = {
    "faq": ("type", "question", "answer", "title", "url"),
    "videos": ("type", "url", "title", "description", "date"),
    "locations": (
        "type",
        "title",
        "url",
        "description",
        "coordinates",
        "postal_address",
        "contact",
        "rating",
        "distance",
        "zoom_level",
    ),
    "news": ("type", "title", "url", "description"),
}


async def _request_json(method: str, url: str, **request_kwargs: Any) -> Any:
    """Issue an HTTP request off the event loop and return the decoded JSON.

    ``requests`` is blocking, so the call runs in the default executor
    instead of stalling every other coroutine. A timeout is always applied
    and non-2xx responses raise ``requests.HTTPError`` rather than being
    parsed as if they were search results.
    """
    send = functools.partial(
        requests.request,
        method,
        url,
        timeout=_REQUEST_TIMEOUT_SECONDS,
        **request_kwargs,
    )
    response = await asyncio.get_running_loop().run_in_executor(None, send)
    response.raise_for_status()
    return response.json()


def _pick(item: Dict[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
    """Return a copy of ``item`` restricted to ``keys``."""
    return {k: v for k, v in item.items() if k in keys}


async def bing_search(query: str, __api_key__: str, top_k: int = 3, **kwargs) -> str:
    """Query the Bing web search API and return the cleaned result as JSON.

    Args:
        query: Search text.
        __api_key__: Bing subscription key (sent as a request header).
        top_k: Number of results requested from Bing (``count``).
        **kwargs: Accepted for call compatibility and ignored.

    Raises:
        requests.HTTPError: If Bing answers with a non-success status.
    """
    search_response = await _request_json(
        "GET",
        "https://api.bing.microsoft.com/v7.0/search",
        params={
            "count": top_k,
            "textDecorations": True,
            "textFormat": "HTML",
            "q": query,
        },
        headers={"Ocp-Apim-Subscription-Key": __api_key__},
    )
    return json.dumps(_bing_clean_response(search_response))


def _bing_clean_response(search_response):
    """Reduce a Bing response to ``{"query": ..., "top_k": [...]}``.

    Web pages are kept as dicts with name/url/snippet. If a ``news`` section
    is present, its items (name/url/description) are appended as ONE nested
    list -- that shape is preserved for existing consumers.
    """
    query = search_response.get("queryContext", {}).get("originalQuery")
    clean_response = [
        _pick(page, ("name", "url", "snippet"))
        for page in search_response.get("webPages", {}).get("value", [])
    ]
    if "news" in search_response:
        clean_news = [
            _pick(item, ("name", "url", "description"))
            for item in search_response["news"].get("value", [])
        ]
        # NOTE(review): nested list inside a list of dicts looks accidental,
        # but flattening would change the returned shape -- left as is.
        clean_response.append(clean_news)

    return {"query": query, "top_k": clean_response}


async def brave_search(query: str, __api_key__: str) -> str:
    """Query the Brave web search API and return the cleaned result as JSON.

    Raises:
        requests.HTTPError: If Brave answers with a non-success status.
    """
    search_response = await _request_json(
        "GET",
        "https://api.search.brave.com/res/v1/web/search",
        params={"q": query},
        headers={
            "X-Subscription-Token": __api_key__,
            "Accept-Encoding": "gzip",
            "Accept": "application/json",
        },
    )
    return json.dumps(_clean_brave_response(search_response))


def _clean_brave_response(search_response, top_k=3):
    """Reduce a Brave response to ``{"query": ..., "top_k": [...]}``.

    Walks the first ``top_k`` entries of ``mixed.main``. For ``web`` and
    ``infobox`` entries the single result at the entry's ``index`` is kept;
    for ``faq``, ``videos``, ``locations`` and ``news`` every result of that
    section is kept; any other type contributes an empty list. Missing
    sections no longer raise ``KeyError``.
    """
    query = search_response.get("query", {}).get("original")
    clean_response = []
    for entry in search_response.get("mixed", {}).get("main", [])[:top_k]:
        r_type = entry["type"]
        # Look the section up defensively: unknown types may have no section.
        results = search_response.get(r_type, {}).get("results", [])
        if r_type in _BRAVE_SINGLE_RESULT_KEYS:
            # Single-result types: take only the result the entry points at.
            idx = entry["index"]
            cleaned = (
                _pick(results[idx], _BRAVE_SINGLE_RESULT_KEYS[r_type])
                if idx < len(results)
                else {}
            )
        elif r_type in _BRAVE_LIST_RESULT_KEYS:
            # List types: take every result of the section.
            keys = _BRAVE_LIST_RESULT_KEYS[r_type]
            cleaned = [_pick(result, keys) for result in results]
        else:
            cleaned = []

        clean_response.append(cleaned)

    return {"query": query, "top_k": clean_response}


async def tavily_search(query: str, __api_key__: str) -> str:
    """Query the Tavily search API and return the cleaned result as JSON.

    Raises:
        requests.HTTPError: If Tavily answers with a non-success status.
    """
    search_response = await _request_json(
        "POST",
        "https://api.tavily.com/search",
        json={"api_key": __api_key__, "query": query},
    )
    return json.dumps(_clean_tavily_response(search_response))


def _clean_tavily_response(search_response, top_k=3):
    """Reduce a Tavily response to its query and the first ``top_k`` results."""
    return {
        "query": search_response["query"],
        "top_k": search_response["results"][:top_k],
    }


async def print_tool(query: str, __api_key__: str) -> str:
    """Log the query and report success.

    The API key is accepted for signature parity with the search tools but is
    deliberately NOT logged: secrets must not end up in log files.
    """
    logger.info("print_tool called with query: %s", query)
    return json.dumps({"result": "success"})


# ---------------------------------------------------------------------------
# llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py
# ---------------------------------------------------------------------------

import logging
from enum import Enum
from typing import Any, Dict

import llama_stack.providers.inline.tool_runtime.meta_reference.builtins as builtins

from llama_stack.apis.tools import Tool, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import MetaReferenceToolRuntimeConfig

logger = logging.getLogger(__name__)


class ToolType(Enum):
    """Built-in tools this provider can run; each name matches a coroutine
    of the same name in the ``builtins`` module."""

    bing_search = "bing_search"
    brave_search = "brave_search"
    tavily_search = "tavily_search"
    print_tool = "print_tool"


class MetaReferenceToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    """Tool runtime that dispatches invocations to the built-in tools."""

    def __init__(self, config: MetaReferenceToolRuntimeConfig):
        self.config = config

    async def initialize(self):
        # NOTE(review): the body of initialize() lies outside the visible
        # hunks; a no-op is assumed here -- confirm against the full file.
        pass

    async def register_tool(self, tool: Tool):
        """Accept ``tool`` only if it maps to one of the built-in tools.

        Raises:
            ValueError: If ``tool.provider_resource_id`` is not a ``ToolType``.
        """
        logger.info("registering tool %s", tool.identifier)
        if tool.provider_resource_id not in ToolType.__members__:
            raise ValueError(
                f"Tool {tool.identifier} not a supported tool by Meta Reference"
            )

    async def unregister_tool(self, tool_id: str) -> None:
        """Unregistering is not supported by this provider."""
        raise NotImplementedError("Meta Reference does not support unregistering tools")

    async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
        """Run the built-in tool registered under ``tool_id``.

        The API key from the request's provider data is injected as the
        reserved ``__api_key__`` argument, overriding any caller value.

        Raises:
            ValueError: If the stored tool does not map to a built-in tool, or
                no API key was supplied with the request.
        """
        tool = await self.tool_store.get_tool(tool_id)
        # Re-check before getattr so only the whitelisted coroutines can be
        # resolved from the builtins module, never an arbitrary attribute.
        if tool.provider_resource_id not in ToolType.__members__:
            raise ValueError(
                f"Tool {tool.identifier} not a supported tool by Meta Reference"
            )
        if args.get("__api_key__") is not None:
            logger.warning(
                "__api_key__ is a reserved argument for this tool: %s", tool_id
            )
        # Build a new dict: writing the secret into the caller's ``args``
        # would leak it to whoever still holds that dict.
        call_args = {**args, "__api_key__": self._get_api_key()}
        handler = getattr(builtins, tool.provider_resource_id)
        return await handler(**call_args)

    def _get_api_key(self) -> str:
        """Return the API key from the request's provider data.

        Raises:
            ValueError: If no provider data or no ``api_key`` was sent.
        """
        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.api_key:
            raise ValueError(
                'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key> }'
            )
        return provider_data.api_key


# ---------------------------------------------------------------------------
# llama_stack/providers/registry/tool_runtime.py
# The meta-reference provider spec in available_providers() additionally sets:
#   provider_data_validator="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceProviderDataValidator"
# ---------------------------------------------------------------------------