From 2b42895ff78b965077852fc800620567e74bbebf Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 17 Dec 2024 14:00:29 -0800
Subject: [PATCH] migrate tools and make tool runtime discover them

---
 .../meta_reference/meta_reference.py          |  95 +++++--
 .../tool_runtime/meta_reference/tools/base.py |  35 +++
 .../meta_reference/tools/bing_search.py       |  67 +++++
 .../meta_reference/tools/brave_search.py      | 101 +++++++
 .../meta_reference/tools/code_interpreter.py  |  53 ++++
 .../tools/ipython_tool/__init__.py            |   5 +
 .../tools/ipython_tool/code_env_prefix.py     | 133 +++++++++
 .../tools/ipython_tool/code_execution.py      | 256 ++++++++++++++++++
 .../ipython_tool/matplotlib_custom_backend.py |  90 ++++++
 .../tools/ipython_tool/utils.py               |  21 ++
 .../meta_reference/tools/photogen.py          |  38 +++
 .../meta_reference/tools/tavily_search.py     |  42 +++
 .../meta_reference/tools/wolfram_alpha.py     |  96 +++++++
 13 files changed, 1007 insertions(+), 25 deletions(-)
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/bing_search.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/brave_search.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/code_interpreter.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/__init__.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/code_env_prefix.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/code_execution.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/utils.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/photogen.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/tavily_search.py
 create mode 100644 llama_stack/providers/inline/tool_runtime/meta_reference/tools/wolfram_alpha.py

diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py
index 2fea15435d..89efafecd5 100644
--- a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py
@@ -4,55 +4,100 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
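+"""Meta-reference tool runtime: discovers BaseTool subclasses in the tools
+package and instantiates them on demand."""
+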
+import importlib
 import logging
-from enum import Enum
-from typing import Any, Dict
-
-import llama_stack.providers.inline.tool_runtime.meta_reference.builtins as builtins
+import pkgutil
+from typing import Any, Dict, Optional, Type
 
 from llama_stack.apis.tools import Tool, ToolRuntime
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
 
 from .config import MetaReferenceToolRuntimeConfig
 
 logger = logging.getLogger(__name__)
 
 
-class ToolType(Enum):
-    bing_search = "bing_search"
-    brave_search = "brave_search"
-    tavily_search = "tavily_search"
-    print_tool = "print_tool"
-
-
 class MetaReferenceToolRuntimeImpl(
     ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
 ):
     def __init__(self, config: MetaReferenceToolRuntimeConfig):
         self.config = config
+        self.tools: Dict[str, Type[BaseTool]] = {}
+        self.tool_instances: Dict[str, BaseTool] = {}
+        self._discover_tools()
+
+    def _discover_tools(self):
+        # Import all tools from the tools package that ships with this provider
+        tools_package = "llama_stack.providers.inline.tool_runtime.meta_reference.tools"
+        package = importlib.import_module(tools_package)
+
+        for _, name, _ in pkgutil.iter_modules(package.__path__):
+            module = importlib.import_module(f"{tools_package}.{name}")
+            for attr_name in dir(module):
+                attr = getattr(module, attr_name)
+                if (
+                    isinstance(attr, type)
+                    and issubclass(attr, BaseTool)
+                    and attr != BaseTool
+                ):
+                    self.tools[attr.tool_id()] = attr
+
+    async def _create_tool_instance(
+        self, tool_id: str, tool_def: Optional[Tool] = None
+    ) -> BaseTool:
+        """Create a new tool instance with proper configuration"""
+        if tool_id not in self.tools:
+            raise ValueError(f"Tool {tool_id} not found in available tools")
+
+        tool_class = self.tools[tool_id]
+
+        # Get tool definition if not provided
+        if tool_def is None:
+            tool_def = await self.tool_store.get_tool(tool_id)
+
+        # Build configuration
+        config = dict(tool_def.provider_metadata.get("config") or {})
+        if tool_class.requires_api_key:
+            config["api_key"] = self._get_api_key()
+
+        return tool_class(config=config)
 
     async def initialize(self):
         pass
 
     async def register_tool(self, tool: Tool):
-        print(f"registering tool {tool.identifier}")
-        if tool.provider_resource_id not in ToolType.__members__:
-            raise ValueError(
-                f"Tool {tool.identifier} not a supported tool by Meta Reference"
-            )
+        if tool.identifier not in self.tools:
+            raise ValueError(f"Tool {tool.identifier} not found in available tools")
 
-    async def unregister_tool(self, tool_id: str) -> None:
-        raise NotImplementedError("Meta Reference does not support unregistering tools")
+        # Validate provider_metadata against tool's config type if specified
+        tool_class = self.tools[tool.identifier]
+        config_type = tool_class.get_provider_config_type()
+        if (
+            config_type
+            and tool.provider_metadata
+            and tool.provider_metadata.get("config")
+        ):
+            config_type(**tool.provider_metadata.get("config"))
+
+        self.tool_instances[tool.identifier] = await self._create_tool_instance(
+            tool.identifier, tool
+        )
 
     async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
-        tool = await self.tool_store.get_tool(tool_id)
-        if args.get("__api_key__") is not None:
-            logger.warning(
-                "__api_key__ is a reserved argument for this tool: {tool_id}"
-            )
-        args["__api_key__"] = self._get_api_key()
-        return await getattr(builtins, tool.provider_resource_id)(**args)
+        if tool_id not in self.tools:
+            raise ValueError(f"Tool {tool_id} not found")
+
+        if tool_id not in self.tool_instances:
+            self.tool_instances[tool_id] = await self._create_tool_instance(tool_id)
+
+        return await self.tool_instances[tool_id].execute(**args)
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        if tool_id in self.tool_instances:
+            del self.tool_instances[tool_id]
 
     def _get_api_key(self) -> str:
         provider_data = self.get_request_provider_data()
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py
new file mode 100644
index 0000000000..79e20f85ec
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Type, TypeVar
+
+T = TypeVar("T")
+
+
+class BaseTool(ABC):
+    """Base class for all tools"""
+
+    requires_api_key: bool = False
+
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        self.config = config or {}
+
+    @classmethod
+    @abstractmethod
+    def tool_id(cls) -> str:
+        """Unique identifier for the tool"""
+        pass
+
+    @abstractmethod
+    async def execute(self, **kwargs) -> Any:
+        """Execute the tool with given arguments"""
+        pass
+
+    @classmethod
+    def get_provider_config_type(cls) -> Optional[Type[T]]:
+        """Override to specify a Pydantic model for tool configuration"""
+        return None
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/bing_search.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/bing_search.py
new file mode 100644
index 0000000000..0ccc598462
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/bing_search.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
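+
+"""Bing web search tool: queries the Bing Web Search v7 API and returns
+cleaned web page and news results as a JSON string."""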
+
+import json
+
+import requests
+
+from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
+from pydantic import BaseModel
+
+
+class BingSearchConfig(BaseModel):
+    api_key: str
+    max_results: int = 5
+
+
+class BingSearchTool(BaseTool):
+    requires_api_key: bool = True
+
+    @classmethod
+    def tool_id(cls) -> str:
+        return "bing_search"
+
+    @classmethod
+    def get_provider_config_type(cls):
+        return BingSearchConfig
+
+    async def execute(self, query: str) -> str:
+        config = BingSearchConfig(**self.config)
+        url = "https://api.bing.microsoft.com/v7.0/search"
+        headers = {
+            "Ocp-Apim-Subscription-Key": config.api_key,
+        }
+        params = {
+            "count": config.max_results,
+            "textDecorations": True,
+            "textFormat": "HTML",
+            "q": query,
+        }
+
+        response = requests.get(url=url, params=params, headers=headers)
+        response.raise_for_status()
+        return json.dumps(self._clean_response(response.json()))
+
+    def _clean_response(self, search_response):
+        clean_response = []
+        query = search_response["queryContext"]["originalQuery"]
+        if "webPages" in search_response:
+            pages = search_response["webPages"]["value"]
+            for p in pages:
+                selected_keys = {"name", "url", "snippet"}
+                clean_response.append(
+                    {k: v for k, v in p.items() if k in selected_keys}
+                )
+        if "news" in search_response:
+            clean_news = []
+            news = search_response["news"]["value"]
+            for n in news:
+                selected_keys = {"name", "url", "description"}
+                clean_news.append({k: v for k, v in n.items() if k in selected_keys})
+            clean_response.append(clean_news)
+
+        return {"query": query, "results": clean_response}
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/brave_search.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/brave_search.py
new file mode 100644
index 0000000000..efcf5d8289
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/brave_search.py
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
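+
+"""Brave web search tool: queries the Brave Search API and strips each
+result type down to a small set of relevant keys."""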
+
+from typing import Dict
+
+import requests
+
+from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
+from pydantic import BaseModel
+
+
+class BraveSearchConfig(BaseModel):
+    api_key: str
+    max_results: int = 3
+
+
+class BraveSearchTool(BaseTool):
+    requires_api_key: bool = True
+
+    @classmethod
+    def tool_id(cls) -> str:
+        return "brave_search"
+
+    @classmethod
+    def get_provider_config_type(cls):
+        return BraveSearchConfig
+
+    async def execute(self, query: str) -> Dict:
+        config = BraveSearchConfig(**self.config)
+        url = "https://api.search.brave.com/res/v1/web/search"
+        headers = {
+            "X-Subscription-Token": config.api_key,
+            "Accept-Encoding": "gzip",
+            "Accept": "application/json",
+        }
+        payload = {"q": query}
+        response = requests.get(url=url, params=payload, headers=headers)
+        response.raise_for_status()
+        return self._clean_brave_response(response.json(), config.max_results)
+
+    def _clean_brave_response(self, search_response, top_k=3):
+        query = None
+        clean_response = []
+        if "query" in search_response:
+            if "original" in search_response["query"]:
+                query = search_response["query"]["original"]
+        if "mixed" in search_response:
+            mixed_results = search_response["mixed"]
+            for m in mixed_results["main"][:top_k]:
+                r_type = m["type"]
+                results = search_response[r_type]["results"]
+                cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
+                clean_response.append(cleaned)
+
+        return {"query": query, "results": clean_response}
+
+    def _clean_result_by_type(self, r_type, results, idx=None):
+        type_cleaners = {
+            "web": (
+                ["type", "title", "url", "description", "date", "extra_snippets"],
+                lambda x: x[idx],
+            ),
+            "faq": (["type", "question", "answer", "title", "url"], lambda x: x),
+            "infobox": (
+                ["type", "title", "url", "description", "long_desc"],
+                lambda x: x[idx],
+            ),
+            "videos": (["type", "url", "title", "description", "date"], lambda x: x),
+            "locations": (
+                [
+                    "type",
+                    "title",
+                    "url",
+                    "description",
+                    "coordinates",
+                    "postal_address",
+                    "contact",
+                    "rating",
+                    "distance",
+                    "zoom_level",
+                ],
+                lambda x: x,
+            ),
+            "news": (["type", "title", "url", "description"], lambda x: x),
+        }
+
+        if r_type not in type_cleaners:
+            return []
+
+        selected_keys, result_selector = type_cleaners[r_type]
+        results = result_selector(results)
+
+        if isinstance(results, list):
+            return [
+                {k: v for k, v in item.items() if k in selected_keys}
+                for item in results
+            ]
+        return {k: v for k, v in results.items() if k in selected_keys}
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/code_interpreter.py
new file mode 100644
index 0000000000..8be7816bf2
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/code_interpreter.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
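+
+"""Code interpreter tool: runs a Python snippet in a bwrap sandbox via
+CodeExecutor and returns its stdout/stderr."""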
+
+import tempfile
+from typing import Dict, Optional
+
+from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
+from pydantic import BaseModel
+
+from .ipython_tool.code_execution import (
+    CodeExecutionContext,
+    CodeExecutionRequest,
+    CodeExecutor,
+)
+
+
+class CodeInterpreterConfig(BaseModel):
+    matplotlib_dump_dir: Optional[str] = None
+
+
+class CodeInterpreterTool(BaseTool):
+    @classmethod
+    def tool_id(cls) -> str:
+        return "code_interpreter"
+
+    @classmethod
+    def get_provider_config_type(cls):
+        return CodeInterpreterConfig
+
+    async def execute(self, code: str) -> Dict:
+        config = CodeInterpreterConfig(**self.config)
+
+        ctx = CodeExecutionContext(
+            matplotlib_dump_dir=config.matplotlib_dump_dir or tempfile.mkdtemp(),
+        )
+        executor = CodeExecutor(ctx)
+
+        req = CodeExecutionRequest(scripts=[code])
+        result = executor.execute(req)
+
+        response = {"status": result["process_status"], "output": []}
+
+        for out_type in ["stdout", "stderr"]:
+            if result[out_type]:
+                response["output"].append(
+                    {"type": out_type, "content": result[out_type]}
+                )
+
+        return response
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/__init__.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/__init__.py
new file mode 100644
index 0000000000..756f351d88
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/code_env_prefix.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/code_env_prefix.py
new file mode 100644
index 0000000000..10f64ec94f
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/code_env_prefix.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
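+
+# NOTE: this module is never imported directly; utils.get_code_env_prefix()
+# reads its source and code_execution.py prepends it to user code, so the
+# hardening below runs before anything else inside the sandbox.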
+
+import errno
+
+# Disabling potentially dangerous functions
+import os as _os
+from functools import partial
+
+os_funcs_to_disable = [
+    "kill",
+    "system",
+    "putenv",
+    "remove",
+    "removedirs",
+    "rmdir",
+    "fchdir",
+    "setuid",
+    "fork",
+    "forkpty",
+    "killpg",
+    "rename",
+    "renames",
+    "truncate",
+    "replace",
+    # "unlink",  # Commented out as this was blocking matplotlib from rendering plots correctly
+    "fchmod",
+    "fchown",
+    "chmod",
+    "chown",
+    "chroot",
+    "lchflags",
+    "lchmod",
+    "lchown",
+    "chdir",
+]
+
+
+def call_not_allowed(*args, **kwargs):
+    raise OSError(errno.EPERM, "Calls are not permitted in this environment")
+
+
+for func_name in os_funcs_to_disable:
+    if hasattr(_os, func_name):
+        setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
+
+import shutil as _shutil
+
+for func_name in ["rmtree", "move", "chown"]:
+    if hasattr(_shutil, func_name):
+        setattr(
+            _shutil,
+            func_name,
+            partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
+        )
+
+import subprocess as _subprocess
+
+
+def popen_not_allowed(*args, **kwargs):
+    raise _subprocess.CalledProcessError(
+        -1,
+        args[0] if args else "unknown",
+        stderr="subprocess.Popen is not allowed in this environment",
+    )
+
+
+_subprocess.Popen = popen_not_allowed
+
+
+import atexit as _atexit
+import builtins as _builtins
+import io as _io
+import json as _json
+import sys as _sys
+
+# NB! The following "unused" imports are crucial; make sure not to remove
+# them with linters - they're used in code_execution.py
+from contextlib import (  # noqa
+    contextmanager as _contextmanager,
+    redirect_stderr as _redirect_stderr,
+    redirect_stdout as _redirect_stdout,
+)
+from multiprocessing.connection import Connection as _Connection
+
+# Mangle imports to avoid polluting model execution namespace.
+
+_IO_SINK = _io.StringIO()
+_NETWORK_TIMEOUT = 5
+_NETWORK_CONNECTIONS = None
+
+
+def _open_connections():
+    global _NETWORK_CONNECTIONS
+    if _NETWORK_CONNECTIONS is not None:
+        # Ensure connections only opened once.
+        return _NETWORK_CONNECTIONS
+    req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
+    req_con = _Connection(int(req_w_fd), readable=False)
+    resp_con = _Connection(int(resp_r_fd), writable=False)
+    _NETWORK_CONNECTIONS = (req_con, resp_con)
+    return _NETWORK_CONNECTIONS
+
+
+_builtins._open_connections = _open_connections
+
+
+@_atexit.register
+def _close_connections():
+    global _NETWORK_CONNECTIONS
+    if _NETWORK_CONNECTIONS is None:
+        return
+    for con in _NETWORK_CONNECTIONS:
+        con.close()
+    del _NETWORK_CONNECTIONS
+
+
+def _network_call(request):
+    # NOTE: We communicate with the parent process in json, encoded
+    # in raw bytes. We do this because native send/recv methods use
+    # pickle which involves execution of arbitrary code.
+    _open_connections()
+    req_con, resp_con = _NETWORK_CONNECTIONS
+
+    req_con.send_bytes(_json.dumps(request).encode("utf-8"))
+    if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
+        raise Exception(f"Network request timed out: {_json.dumps(request)}")
+    else:
+        return _json.loads(resp_con.recv_bytes().decode("utf-8"))
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/code_execution.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/code_execution.py
new file mode 100644
index 0000000000..fa2e367e58
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/code_execution.py
@@ -0,0 +1,256 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import json
+import multiprocessing
+import os
+import re
+import subprocess
+import sys
+import tempfile
+import textwrap
+import time
+from dataclasses import dataclass
+from datetime import datetime
+from io import BytesIO
+from pathlib import Path
+from typing import List
+
+from PIL import Image
+
+from .utils import get_code_env_prefix
+
+TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
+TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
+
+DIRNAME = Path(__file__).parent
+
+CODE_EXEC_TIMEOUT = 20
+CODE_ENV_PREFIX = get_code_env_prefix()
+
+STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
+with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
+{code}\
+"""
+
+TRYEXCEPT_WRAPPER_TEMPLATE = """\
+try:
+{code}
+except:
+    pass\
+"""
+
+
+def generate_bwrap_command(bind_dirs: List[str]) -> str:
+    """
+    Generate the bwrap argument string: the whole filesystem is mounted
+    read-only and only the given bind_dirs are bound read-write.
+    """
+    bwrap_args = ""
+    bwrap_args += "--ro-bind / / "
+    # Add the --dev flag to mount device files
+    bwrap_args += "--dev /dev "
+    for d in bind_dirs:
+        bwrap_args += f"--bind {d} {d} "
+
+    # Add the --unshare-all flag to isolate the sandbox from the rest of the system
+    bwrap_args += "--unshare-all "
+    # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
+    bwrap_args += "--die-with-parent "
+    return bwrap_args
+
+
+@dataclass
+class CodeExecutionContext:
+    matplotlib_dump_dir: str
+    use_proxy: bool = False
+
+
+@dataclass
+class CodeExecutionRequest:
+    scripts: List[str]
+    only_last_cell_stdouterr: bool = True
+    only_last_cell_fail: bool = True
+    seed: int = 0
+    strip_fpaths_in_stderr: bool = True
+
+
+class CodeExecutor:
+    def __init__(self, context: CodeExecutionContext):
+        self.context = context
+
+    def execute(self, req: CodeExecutionRequest) -> dict:
+        scripts = req.scripts
+        for i in range(len(scripts) - 1):
+            if req.only_last_cell_stdouterr:
+                scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(
+                    code=textwrap.indent(scripts[i], " " * 4)
+                )
+            if req.only_last_cell_fail:
+                scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(
+                    code=textwrap.indent(scripts[i], " " * 4)
+                )
+
+        # Seeds prefix:
+        seed = req.seed
+        seeds_prefix = f"""\
+def _set_seeds():
+    import random
+    random.seed({seed})
+    import numpy as np
+    np.random.seed({seed})
+_set_seeds()\
+"""
+
+        script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
+        with tempfile.TemporaryDirectory() as dpath:
+            bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
+            cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
+            code_fpath = os.path.join(dpath, "code.py")
+            with open(code_fpath, "w") as f:
+                f.write(script)
+
+            try:
+                python_path = os.environ.get("PYTHONPATH", "")
+                env = dict(
+                    os.environ,
+                    PYTHONHASHSEED=str(seed),
+                    MPLCONFIGDIR=dpath,
+                    MPLBACKEND="module://matplotlib_custom_backend",
+                    PYTHONPATH=f"{DIRNAME}:{python_path}",
+                )
+                stdout, stderr, returncode = do_subprocess(
+                    cmd=cmd,
+                    env=env,
+                    ctx=self.context,
+                )
+
+                stderr = stderr.strip()
+                if req.strip_fpaths_in_stderr:
+                    pattern = r'File "([^"]+)", line (\d+)'
+                    stderr = re.sub(pattern, r"line \2", stderr)
+
+                return {
+                    "process_status": "completed",
+                    "returncode": returncode,
+                    "stdout": stdout.strip(),
+                    "stderr": stderr,
+                }
+
+            except subprocess.TimeoutExpired:
+                return {
+                    "process_status": "timeout",
"stdout": "Timed out", + "stderr": "Timed out", + } + + except Exception as e: + return { + "process_status": "error", + "error_type": type(e).__name__, + "stderr": str(e), + "stdout": str(e), + } + + +def process_matplotlib_response(response, matplotlib_dump_dir: str): + image_data = response["image_data"] + # Convert the base64 string to a bytes object + images = [base64.b64decode(d["image_base64"]) for d in image_data] + # Create a list of PIL images from the bytes objects + images = [Image.open(BytesIO(img)) for img in images] + # Create a list of image paths + image_paths = [] + for i, img in enumerate(images): + # create new directory for each day to better organize data: + dump_dname = datetime.today().strftime("%Y-%m-%d") + dump_dpath = Path(matplotlib_dump_dir, dump_dname) + dump_dpath.mkdir(parents=True, exist_ok=True) + # save image into a file + dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png" + dump_fpath = dump_dpath / dump_fname + img.save(dump_fpath, "PNG") + image_paths.append(str(dump_fpath)) + + # this is kind of convoluted, we send back this response to the subprocess which + # prints it out + info = { + "filepath": str(image_paths[-1]), + "mimetype": "image/png", + } + return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}" + + +def execute_subprocess_request(request, ctx: CodeExecutionContext): + "Route requests from the subprocess (via network Pipes) to the internet/tools." + if request["type"] == "matplotlib": + return process_matplotlib_response(request, ctx.matplotlib_dump_dir) + else: + raise Exception(f'Unrecognised network request type: {request["type"]}') + + +def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext): + # Create Pipes to be used for any external tool/network requests. + req_r, req_w = multiprocessing.Pipe(duplex=False) + resp_r, resp_w = multiprocessing.Pipe(duplex=False) + + cmd += [str(req_w.fileno()), str(resp_r.fileno())] + proc = subprocess.Popen( + cmd, + pass_fds=(req_w.fileno(), resp_r.fileno()), + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + env=env, + ) + + # Close unnecessary fds. + req_w.close() + resp_r.close() + + pipe_close = False + done_read = False + start = time.monotonic() + while proc.poll() is None and not pipe_close: + if req_r.poll(0.1): + # NB: Python pipe semantics for poll and recv mean that + # poll() returns True is a pipe is closed. + # CF old school PEP from '09 + # https://bugs.python.org/issue5573 + try: + request = json.loads(req_r.recv_bytes().decode("utf-8")) + response = execute_subprocess_request(request, ctx) + + resp_w.send_bytes(json.dumps(response).encode("utf-8")) + except EOFError: + # The request pipe is closed - set a marker to exit + # after the next attempt at reading stdout/stderr. + pipe_close = True + + try: + # If lots has been printed, pipe might be full but + # proc cannot exit until all the stdout/stderr + # been written/read. + stdout, stderr = proc.communicate(timeout=0.3) + done_read = True + except subprocess.TimeoutExpired: + # The program has not terminated. Ignore it, there + # may be more network/tool requests. + continue + if time.monotonic() - start > CODE_EXEC_TIMEOUT: + proc.terminate() + raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT) + + if not done_read: + # Solve race condition where process terminates before + # we hit the while loop. 
+        stdout, stderr = proc.communicate(timeout=0.3)
+
+    resp_w.close()
+    req_r.close()
+    return stdout, stderr, proc.returncode
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py
new file mode 100644
index 0000000000..7fec08cf24
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+A custom Matplotlib backend that overrides the show method to return image bytes.
+"""
+
+import base64
+import io
+import json as _json
+import logging
+
+import matplotlib
+from matplotlib.backend_bases import FigureManagerBase
+
+# Import necessary components from Matplotlib
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+log = logging.getLogger(__name__)
+
+
+class CustomFigureCanvas(FigureCanvasAgg):
+    def show(self):
+        # Save the figure to a BytesIO object
+        buf = io.BytesIO()
+        self.print_png(buf)
+        image_bytes = buf.getvalue()
+        buf.close()
+        return image_bytes
+
+
+class CustomFigureManager(FigureManagerBase):
+    def __init__(self, canvas, num):
+        super().__init__(canvas, num)
+
+
+# Mimic module initialization that integrates with the Matplotlib backend system
+def _create_figure_manager(num, *args, **kwargs):
+    """
+    Create a custom figure manager instance.
+    """
+    FigureClass = kwargs.pop("FigureClass", None)  # noqa: N806
+    if FigureClass is None:
+        from matplotlib.figure import Figure
+
+        FigureClass = Figure  # noqa: N806
+    fig = FigureClass(*args, **kwargs)
+    canvas = CustomFigureCanvas(fig)
+    manager = CustomFigureManager(canvas, num)
+    return manager
+
+
+def show():
+    """
+    Handle all figures and potentially return their images as bytes.
+
+    This function iterates over all figures registered with the custom backend,
+    renders them as images in bytes format, and could return a list of bytes objects,
+    one for each figure, or handle them as needed.
+    """
+    image_data = []
+    for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
+        # Get the figure from the manager
+        fig = manager.canvas.figure
+        buf = io.BytesIO()  # Create a buffer for the figure
+        fig.savefig(buf, format="png")  # Save the figure to the buffer in PNG format
+        buf.seek(0)  # Go to the beginning of the buffer
+        image_bytes = buf.getvalue()  # Retrieve bytes value
+        image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+        image_data.append({"image_base64": image_base64})
+        buf.close()
+
+    # _open_connections is injected into builtins by code_env_prefix.py,
+    # which always runs before user code inside the sandbox.
+    req_con, resp_con = _open_connections()
+
+    _json_dump = _json.dumps(
+        {
+            "type": "matplotlib",
+            "image_data": image_data,
+        }
+    )
+    req_con.send_bytes(_json_dump.encode("utf-8"))
+    resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
+    log.info(resp)
+
+
+FigureCanvas = CustomFigureCanvas
+FigureManager = CustomFigureManager
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/utils.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/utils.py
new file mode 100644
index 0000000000..d6f539a39f
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/ipython_tool/utils.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+DIR = os.path.dirname(os.path.realpath(__file__))
+CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
+CODE_ENV_PREFIX = None
+
+
+def get_code_env_prefix() -> str:
+    global CODE_ENV_PREFIX
+
+    if CODE_ENV_PREFIX is None:
+        with open(CODE_ENV_PREFIX_FILE, "r") as f:
+            CODE_ENV_PREFIX = f.read()
+
+    return CODE_ENV_PREFIX
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/photogen.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/photogen.py
new file mode 100644
index 0000000000..96351f390e
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/photogen.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
+from pydantic import BaseModel
+
+
+class PhotogenConfig(BaseModel):
+    dump_dir: str
+
+
+class PhotogenTool(BaseTool):
+    @classmethod
+    def tool_id(cls) -> str:
+        return "photogen"
+
+    @classmethod
+    def get_provider_config_type(cls):
+        return PhotogenConfig
+
+    async def execute(self, query: str) -> Dict:
+        """
+        Implement this to give the model an ability to generate images.
+
+        Return:
+            info = {
+                "filepath": str(image_filepath),
+                "mimetype": "image/png",
+            }
+        """
+        config = PhotogenConfig(**self.config)
+        raise NotImplementedError()
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/tavily_search.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/tavily_search.py
new file mode 100644
index 0000000000..f6030b985d
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/tavily_search.py
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict
+
+import requests
+
+from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
+from pydantic import BaseModel
+
+
+class TavilySearchConfig(BaseModel):
+    api_key: str
+    max_results: int = 3
+
+
+class TavilySearchTool(BaseTool):
+    requires_api_key: bool = True
+
+    @classmethod
+    def tool_id(cls) -> str:
+        return "tavily_search"
+
+    @classmethod
+    def get_provider_config_type(cls):
+        return TavilySearchConfig
+
+    async def execute(self, query: str) -> Dict:
+        config = TavilySearchConfig(**self.config)
+        response = requests.post(
+            "https://api.tavily.com/search",
+            json={"api_key": config.api_key, "query": query},
+        )
+        response.raise_for_status()
+        search_response = response.json()
+        return {
+            "query": search_response["query"],
+            "results": search_response["results"][: config.max_results],
+        }
diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/wolfram_alpha.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/wolfram_alpha.py
new file mode 100644
index 0000000000..5f891a1c9b
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/wolfram_alpha.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+
+import requests
+
+from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
+from pydantic import BaseModel
+
+
+class WolframAlphaConfig(BaseModel):
+    api_key: str
+
+
+class WolframAlphaTool(BaseTool):
+    requires_api_key: bool = True
+
+    @classmethod
+    def tool_id(cls) -> str:
+        return "wolfram_alpha"
+
+    @classmethod
+    def get_provider_config_type(cls):
+        return WolframAlphaConfig
+
+    async def execute(self, query: str) -> str:
+        config = WolframAlphaConfig(**self.config)
+        url = "https://api.wolframalpha.com/v2/query"
+        params = {
+            "input": query,
+            "appid": config.api_key,
+            "format": "plaintext",
+            "output": "json",
+        }
+        response = requests.get(url, params=params)
+        response.raise_for_status()
+        return json.dumps(self._clean_wolfram_alpha_response(response.json()))
+
+    def _clean_wolfram_alpha_response(self, wa_response):
+        remove = {
+            "queryresult": [
+                "datatypes",
+                "error",
+                "timedout",
+                "timedoutpods",
+                "numpods",
+                "timing",
+                "parsetiming",
+                "parsetimedout",
+                "recalculate",
+                "id",
+                "host",
+                "server",
+                "related",
+                "version",
+                {
+                    "pods": [
+                        "scanner",
+                        "id",
+                        "error",
+                        "expressiontypes",
+                        "states",
+                        "infos",
+                        "position",
+                        "numsubpods",
+                    ]
+                },
+                "assumptions",
+            ],
+        }
+
+        result = wa_response.copy()
+        for main_key, to_remove in remove.items():
+            if main_key not in result:
+                continue
+
+            for item in to_remove:
+                if isinstance(item, dict):
+                    for sub_key, sub_items in item.items():
+                        if sub_key == "pods":
+                            pods = result[main_key].get(sub_key, [])
+                            for i, pod in enumerate(pods):
+                                if pod.get("title") == "Result":
+                                    # Write the truncation back into the result;
+                                    # rebinding the local `pods` alone has no effect.
+                                    result[main_key][sub_key] = pods[: i + 1]
+                                    break
+                                for remove_key in sub_items:
+                                    pod.pop(remove_key, None)
+                else:
+                    result[main_key].pop(item, None)
+
+        return result
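
--
Usage sketch (hypothetical, not part of the patch): because _discover_tools()
scans the tools package for BaseTool subclasses, adding a tool requires no
registry edits - dropping a module like the one below into the package is
enough. The module name, tool id, and echo behavior here are illustrative only.

# llama_stack/providers/inline/tool_runtime/meta_reference/tools/echo.py
from typing import Dict

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool


class EchoTool(BaseTool):
    # No API key needed, so requires_api_key keeps its False default and no
    # provider config type is declared.

    @classmethod
    def tool_id(cls) -> str:
        return "echo"

    async def execute(self, message: str) -> Dict:
        # Once registered, invoked via:
        #   await runtime.invoke_tool("echo", {"message": "hi"})
        return {"echo": message}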