Skip to content

Commit

Permalink
migrate tools and make tool runtime discover
Browse files Browse the repository at this point in the history
  • Loading branch information
dineshyv committed Dec 17, 2024
1 parent 8fdf024 commit 2b42895
Show file tree
Hide file tree
Showing 13 changed files with 1,007 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,100 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import importlib
import logging
from enum import Enum
from typing import Any, Dict

import llama_stack.providers.inline.tool_runtime.meta_reference.builtins as builtins
import pkgutil
from typing import Any, Dict, Optional, Type

from llama_stack.apis.tools import Tool, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool

from .config import MetaReferenceToolRuntimeConfig

logger = logging.getLogger(__name__)


class ToolType(Enum):
    """Names of the built-in Meta Reference tools.

    Every member's value is the same string as its name.
    """

    bing_search = "bing_search"
    brave_search = "brave_search"
    tavily_search = "tavily_search"
    print_tool = "print_tool"


class MetaReferenceToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: MetaReferenceToolRuntimeConfig):
    """Keep the provider config and fill the tool registry.

    Both lookup tables are created empty before discovery runs because
    ``_discover_tools`` writes its findings into ``self.tools``.
    """
    # tool id -> instantiated tool (populated by other methods later on).
    self.tool_instances: Dict[str, BaseTool] = {}
    # tool id -> tool class.
    self.tools: Dict[str, Type[BaseTool]] = {}
    self.config = config
    self._discover_tools()

def _discover_tools(self):
    """Populate ``self.tools`` with every concrete tool class that can be found.

    Each module located directly inside the package that holds ``BaseTool``
    is imported, and every concrete ``BaseTool`` subclass found in it is
    stored in ``self.tools`` under the value returned by its ``tool_id()``.
    When two classes report the same id, the one seen last wins.

    Raises:
        ImportError: if the tools package or one of its modules fails to
            import.
    """
    import inspect

    # Derive the package from BaseTool's own module instead of spelling the
    # dotted path out a second time: the previous hard-coded string
    # ("llama_stack.providers.inline.tool_runtime.tools") did not match the
    # package BaseTool is imported from, so the two could not both be right.
    tools_package = BaseTool.__module__.rpartition(".")[0]
    package = importlib.import_module(tools_package)

    for _, name, _ in pkgutil.iter_modules(package.__path__):
        module = importlib.import_module(f"{tools_package}.{name}")
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if not isinstance(attr, type) or not issubclass(attr, BaseTool):
                continue
            # Skip the base class itself and any still-abstract intermediate:
            # neither can be instantiated, and an unimplemented tool_id()
            # would register the class under a meaningless key.
            if attr is BaseTool or inspect.isabstract(attr):
                continue
            self.tools[attr.tool_id()] = attr

async def _create_tool_instance(
    self, tool_id: str, tool_def: Optional[Tool] = None
) -> BaseTool:
    """Create a new, configured instance of the tool registered as ``tool_id``.

    Args:
        tool_id: Key of the tool class in ``self.tools``.
        tool_def: Tool definition carrying ``provider_metadata``. When
            omitted it is fetched from ``self.tool_store``.

    Returns:
        An instance of the tool class built with the ``"config"`` entry of
        the definition's provider metadata, plus an ``"api_key"`` entry when
        the class sets ``requires_api_key``.

    Raises:
        ValueError: if ``tool_id`` is not a discovered tool.
    """
    if tool_id not in self.tools:
        raise ValueError(f"Tool {tool_id} not found in available tools")

    tool_class = self.tools[tool_id]

    # Get tool definition if not provided
    if tool_def is None:
        tool_def = await self.tool_store.get_tool(tool_id)

    # register_tool treats provider_metadata as possibly empty/None, so do the
    # same here instead of calling .get() on it unconditionally.
    metadata = tool_def.provider_metadata or {}
    # Copy so that adding the API key never mutates the stored definition.
    config = dict(metadata.get("config") or {})
    if tool_class.requires_api_key:
        config["api_key"] = self._get_api_key()

    return tool_class(config=config)

async def initialize(self):
    """Do nothing; this implementation has no start-up work to perform."""

async def register_tool(self, tool: Tool):
    """Validate ``tool`` against the discovered registry and instantiate it.

    The discovered registry (``self.tools``) is the single source of truth
    for what can be registered; no separate static allow-list is consulted,
    since such a list would reject tools that discovery actually found.

    Args:
        tool: Tool definition. ``tool.identifier`` selects the tool class and
            ``tool.provider_metadata["config"]``, when present, is passed to
            the tool as its configuration.

    Raises:
        ValueError: if ``tool.identifier`` is not a discovered tool.
        Exception: whatever the tool's config model raises when the supplied
            config does not validate.
    """
    logger.info("registering tool %s", tool.identifier)
    if tool.identifier not in self.tools:
        raise ValueError(f"Tool {tool.identifier} not found in available tools")

    # Validate provider_metadata against the tool's config type if specified.
    tool_class = self.tools[tool.identifier]
    config_type = tool_class.get_provider_config_type()
    tool_config = (tool.provider_metadata or {}).get("config")
    if config_type and tool_config:
        # Constructing the model is the validation; the result is discarded.
        # NOTE(review): config models that declare a required ``api_key``
        # will reject metadata that omits it, even though the key is only
        # injected later by _create_tool_instance - confirm this is intended.
        config_type(**tool_config)

    self.tool_instances[tool.identifier] = await self._create_tool_instance(
        tool.identifier, tool
    )

async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
    """Execute the tool registered as ``tool_id`` with ``args`` as keywords.

    The previous body returned unconditionally from a legacy dispatch through
    the ``builtins`` module, which left the registry-based path below
    unreachable; only the registry-based path is kept. ``args`` is no longer
    mutated.

    Args:
        tool_id: Key of the tool in ``self.tools``.
        args: Keyword arguments forwarded unchanged to the tool's ``execute``.

    Returns:
        Whatever the tool's ``execute`` coroutine returns.

    Raises:
        ValueError: if ``tool_id`` is not a discovered tool.
    """
    if tool_id not in self.tools:
        raise ValueError(f"Tool {tool_id} not found")

    instance = self.tool_instances.get(tool_id)
    if instance is None:
        # Not registered through register_tool yet: build it on first use.
        # NOTE(review): the instance (including any API key placed in its
        # config at creation time) is then reused for later calls - confirm
        # that is acceptable if keys are meant to vary per request.
        instance = await self._create_tool_instance(tool_id)
        self.tool_instances[tool_id] = instance

    return await instance.execute(**args)

async def unregister_tool(self, tool_id: str) -> None:
    """Discard any cached instance for ``tool_id``, then report no support.

    Raises:
        NotImplementedError: always, after the cached instance (if there was
            one) has been removed from ``self.tool_instances``.
    """
    # NOTE(review): the cache entry is dropped even though the call then
    # fails - confirm whether both effects are intended.
    self.tool_instances.pop(tool_id, None)
    raise NotImplementedError("Meta Reference does not support unregistering tools")

def _get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar

T = TypeVar("T")


class BaseTool(ABC):
    """Abstract parent of every tool.

    Concrete tools must provide ``tool_id`` and ``execute``; they may also
    set ``requires_api_key`` and override ``get_provider_config_type``.
    """

    # Class-level flag, False unless a subclass overrides it.
    requires_api_key: bool = False

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Remember ``config``, substituting an empty dict for a falsy value."""
        self.config = config if config else {}

    @classmethod
    @abstractmethod
    def tool_id(cls) -> str:
        """Return the unique identifier of this tool."""

    @abstractmethod
    async def execute(self, **kwargs) -> Any:
        """Run the tool with the supplied keyword arguments."""

    @classmethod
    def get_provider_config_type(cls) -> Optional[Type[T]]:
        """Return the Pydantic model describing this tool's configuration.

        The base implementation returns ``None``; override it to name a model.
        """
        return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import List

import requests

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class BingSearchConfig(BaseModel):
    """Configuration accepted by the Bing search tool."""

    # Sent as the "Ocp-Apim-Subscription-Key" request header.
    api_key: str
    # Sent as the "count" query parameter of the search request.
    max_results: int = 5


class BingSearchTool(BaseTool):
    """Web search tool that queries the Bing v7.0 search endpoint."""

    requires_api_key: bool = True

    # requests applies no timeout unless one is passed, so an unresponsive
    # server would otherwise stall the call indefinitely.
    _REQUEST_TIMEOUT_SECONDS = 30

    @classmethod
    def tool_id(cls) -> str:
        """Return the identifier this tool is registered under."""
        return "bing_search"

    @classmethod
    def get_provider_config_type(cls):
        """Return the Pydantic model used to validate this tool's config."""
        return BingSearchConfig

    async def execute(self, query: str) -> str:
        """Search Bing for ``query`` and return the cleaned result as JSON.

        Args:
            query: Search terms, sent as the ``q`` parameter.

        Returns:
            A JSON string encoding ``{"query": ..., "results": [...]}`` as
            produced by ``_clean_response``.

        Raises:
            requests.HTTPError: if the endpoint answers with an error status.
            requests.Timeout: if no response arrives within the timeout.
        """
        config = BingSearchConfig(**self.config)
        response = requests.get(
            url="https://api.bing.microsoft.com/v7.0/search",
            params={
                "count": config.max_results,
                "textDecorations": True,
                "textFormat": "HTML",
                "q": query,
            },
            headers={"Ocp-Apim-Subscription-Key": config.api_key},
            timeout=self._REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return json.dumps(self._clean_response(response.json()))

    def _clean_response(self, search_response):
        """Reduce a raw search response to the fields worth returning.

        Web pages contribute one dict each (``name``, ``url``, ``snippet``).
        News items (``name``, ``url``, ``description``) are appended together
        as a single nested list, which is the shape this method has always
        produced.

        Returns:
            ``{"query": <original query or None>, "results": [...]}``.
        """
        page_keys = {"name", "url", "snippet"}
        news_keys = {"name", "url", "description"}
        results = []

        if "webPages" in search_response:
            results.extend(
                {k: v for k, v in page.items() if k in page_keys}
                for page in search_response["webPages"]["value"]
            )
        if "news" in search_response:
            results.append(
                [
                    {k: v for k, v in item.items() if k in news_keys}
                    for item in search_response["news"]["value"]
                ]
            )

        # A response without queryContext should not discard the results.
        query = search_response.get("queryContext", {}).get("originalQuery")
        return {"query": query, "results": results}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

import requests

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class BraveSearchConfig(BaseModel):
    """Configuration accepted by the Brave search tool."""

    # Sent as the "X-Subscription-Token" request header.
    api_key: str
    # Number of leading entries of the response's "mixed" -> "main" list
    # that are turned into results.
    max_results: int = 3


class BraveSearchTool(BaseTool):
    """Web search tool that queries the Brave web search endpoint."""

    requires_api_key: bool = True

    # requests applies no timeout unless one is passed, so an unresponsive
    # server would otherwise stall the call indefinitely.
    _REQUEST_TIMEOUT_SECONDS = 30

    # result type -> (fields to keep, whether the ranking entry's "index"
    # picks one item out of that type's result list).
    _TYPE_CLEANERS = {
        "web": (
            frozenset(
                {"type", "title", "url", "description", "date", "extra_snippets"}
            ),
            True,
        ),
        "faq": (frozenset({"type", "question", "answer", "title", "url"}), False),
        "infobox": (
            frozenset({"type", "title", "url", "description", "long_desc"}),
            True,
        ),
        "videos": (
            frozenset({"type", "url", "title", "description", "date"}),
            False,
        ),
        "locations": (
            frozenset(
                {
                    "type",
                    "title",
                    "url",
                    "description",
                    "coordinates",
                    "postal_address",
                    "contact",
                    "rating",
                    "distance",
                    "zoom_level",
                }
            ),
            False,
        ),
        "news": (frozenset({"type", "title", "url", "description"}), False),
    }

    @classmethod
    def tool_id(cls) -> str:
        """Return the identifier this tool is registered under."""
        return "brave_search"

    @classmethod
    def get_provider_config_type(cls):
        """Return the Pydantic model used to validate this tool's config."""
        return BraveSearchConfig

    async def execute(self, query: str) -> dict:
        """Search Brave for ``query`` and return the cleaned result.

        Args:
            query: Search terms, sent as the ``q`` parameter.

        Returns:
            ``{"query": ..., "results": [...]}`` as produced by
            ``_clean_brave_response``.

        Raises:
            requests.HTTPError: if the endpoint answers with an error status.
            requests.Timeout: if no response arrives within the timeout.
        """
        config = BraveSearchConfig(**self.config)
        response = requests.get(
            url="https://api.search.brave.com/res/v1/web/search",
            params={"q": query},
            headers={
                "X-Subscription-Token": config.api_key,
                "Accept-Encoding": "gzip",
                "Accept": "application/json",
            },
            timeout=self._REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return self._clean_brave_response(response.json(), config.max_results)

    def _clean_brave_response(self, search_response, top_k=3):
        """Walk the first ``top_k`` ranked entries and clean each one.

        The ranking is read from ``search_response["mixed"]["main"]``; each
        entry names a result type and, optionally, an index into that type's
        result list. Entries whose type is unknown or whose section is
        missing from the response contribute an empty list instead of
        aborting the whole search.

        Returns:
            ``{"query": <original query or None>, "results": [...]}``.
        """
        query = search_response.get("query", {}).get("original")
        ranked = search_response.get("mixed", {}).get("main", [])

        clean_response = []
        for entry in ranked[:top_k]:
            r_type = entry["type"]
            # The ranking may name a section the response does not carry.
            results = search_response.get(r_type, {}).get("results", [])
            clean_response.append(
                self._clean_result_by_type(r_type, results, entry.get("index"))
            )

        return {"query": query, "results": clean_response}

    def _clean_result_by_type(self, r_type, results, idx=None):
        """Strip ``results`` of type ``r_type`` down to the fields of interest.

        Args:
            r_type: Result type name, e.g. ``"web"`` or ``"news"``.
            results: The result list of that type from the response.
            idx: Position of the wanted item; only used by the types that
                select a single item ("web" and "infobox").

        Returns:
            A dict for single-item types, a list of dicts for the others, or
            an empty list when the type is unknown or ``idx`` does not
            address an item.
        """
        if r_type not in self._TYPE_CLEANERS:
            return []

        selected_keys, pick_by_index = self._TYPE_CLEANERS[r_type]
        if pick_by_index:
            try:
                results = results[idx]
            except (IndexError, TypeError):
                # idx missing (None) or past the end of the list.
                return []

        if isinstance(results, list):
            return [
                {k: v for k, v in item.items() if k in selected_keys}
                for item in results
            ]
        return {k: v for k, v in results.items() if k in selected_keys}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from typing import Dict, Optional

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel

from .ipython_tool.code_execution import (
CodeExecutionContext,
CodeExecutionRequest,
CodeExecutor,
)


class CodeInterpreterConfig(BaseModel):
    """Configuration accepted by the code interpreter tool."""

    # Directory handed to the execution context for matplotlib output. When
    # left unset, execute() creates a fresh temporary directory instead.
    # Declared Optional because None is the default and therefore must also
    # be accepted when passed explicitly.
    matplotlib_dump_dir: Optional[str] = None


class CodeInterpreterTool(BaseTool):
    """Tool that runs a code snippet through ``CodeExecutor``."""

    @classmethod
    def tool_id(cls) -> str:
        """Return the identifier this tool is registered under."""
        return "code_interpreter"

    @classmethod
    def get_provider_config_type(cls):
        """Return the Pydantic model used to validate this tool's config."""
        return CodeInterpreterConfig

    async def execute(self, code: str) -> Dict:
        """Execute ``code`` and report its status and captured streams.

        Args:
            code: Script text, submitted as the only script of the request.

        Returns:
            ``{"status": <process status>, "output": [...]}`` where the list
            holds one ``{"type", "content"}`` dict for each of stdout and
            stderr (in that order) that is non-empty.
        """
        settings = CodeInterpreterConfig(**self.config)

        # Fall back to a new temporary directory when none is configured.
        dump_dir = settings.matplotlib_dump_dir or tempfile.mkdtemp()
        executor = CodeExecutor(CodeExecutionContext(matplotlib_dump_dir=dump_dir))
        outcome = executor.execute(CodeExecutionRequest(scripts=[code]))

        status = outcome["process_status"]
        streams = [
            {"type": stream, "content": outcome[stream]}
            for stream in ("stdout", "stderr")
            if outcome[stream]
        ]
        return {"status": status, "output": streams}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
Loading

0 comments on commit 2b42895

Please sign in to comment.