diff --git a/docs/pydoc/config/tools_api.yml b/docs/pydoc/config/tools_api.yml
index d3f953087f..3050e6c587 100644
--- a/docs/pydoc/config/tools_api.yml
+++ b/docs/pydoc/config/tools_api.yml
@@ -2,7 +2,7 @@ loaders:
   - type: haystack_pydoc_tools.loaders.CustomPythonLoader
     search_path: [../../../haystack/tools]
     modules:
-      ["tool", "from_function"]
+      ["tool", "from_function", "component_tool"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
diff --git a/haystack/tools/__init__.py b/haystack/tools/__init__.py
index 4601ac71c6..ccb274d49a 100644
--- a/haystack/tools/__init__.py
+++ b/haystack/tools/__init__.py
@@ -2,7 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+# ruff: noqa: I001 (ignore import order as we need to import Tool before ComponentTool)
 from haystack.tools.from_function import create_tool_from_function, tool
 from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
+from haystack.tools.component_tool import ComponentTool
 
-__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace", "create_tool_from_function", "tool"]
+
+__all__ = [
+    "Tool",
+    "_check_duplicate_tool_names",
+    "deserialize_tools_inplace",
+    "create_tool_from_function",
+    "tool",
+    "ComponentTool",
+]
diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py
new file mode 100644
index 0000000000..cc77ceca01
--- /dev/null
+++ b/haystack/tools/component_tool.py
@@ -0,0 +1,330 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import fields, is_dataclass
+from inspect import getdoc
+from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
+
+from pydantic import TypeAdapter
+
+from haystack import logging
+from haystack.core.component import Component
+from haystack.core.serialization import (
+    component_from_dict,
+    component_to_dict,
+    generate_qualified_class_name,
+    import_class_by_name,
+)
+from haystack.lazy_imports import LazyImport
+from haystack.tools import Tool
+from haystack.tools.errors import SchemaGenerationError
+
+with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
+    from docstring_parser import parse
+
+
+logger = logging.getLogger(__name__)
+
+
+class ComponentTool(Tool):
+    """
+    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.
+
+    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
+    which are derived from the component's `run` method signature and type hints.
+
+    Key features:
+    - Automatic LLM tool calling schema generation from component input sockets
+    - Type conversion and validation for component inputs
+    - Support for types:
+        - Dataclasses
+        - Lists of dataclasses
+        - Basic types (str, int, float, bool, dict)
+        - Lists of basic types
+    - Automatic name generation from component class name
+    - Description extraction from component docstrings
+
+    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
+    Then create the ComponentTool by passing that component to the ComponentTool constructor.
+    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.
+
+    ```python
+    from haystack import component, Pipeline
+    from haystack.tools import ComponentTool
+    from haystack.components.websearch import SerperDevWebSearch
+    from haystack.utils import Secret
+    from haystack.components.tools.tool_invoker import ToolInvoker
+    from haystack.components.generators.chat import OpenAIChatGenerator
+    from haystack.dataclasses import ChatMessage
+
+    # Create a SerperDev search component
+    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)
+
+    # Create a tool from the component
+    tool = ComponentTool(
+        component=search,
+        name="web_search",  # Optional: defaults to "serper_dev_web_search"
+        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
+    )
+
+    # Create pipeline with OpenAIChatGenerator and ToolInvoker
+    pipeline = Pipeline()
+    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
+    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
+
+    # Connect components
+    pipeline.connect("llm.replies", "tool_invoker.messages")
+
+    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")
+
+    # Run pipeline
+    result = pipeline.run({"llm": {"messages": [message]}})
+
+    print(result)
+    ```
+
+    """
+
+    def __init__(self, component: Component, name: Optional[str] = None, description: Optional[str] = None):
+        """
+        Create a Tool instance from a Haystack component.
+
+        :param component: The Haystack component to wrap as a tool.
+        :param name: Optional name for the tool (defaults to snake_case of component class name).
+        :param description: Optional description (defaults to component's docstring).
+        :raises ValueError: If the component is invalid or schema generation fails.
+        """
+        if not isinstance(component, Component):
+            message = (
+                f"Object {component!r} is not a Haystack component. "
+                "Use ComponentTool only with Haystack component instances."
+            )
+            raise ValueError(message)
+
+        if getattr(component, "__haystack_added_to_pipeline__", None):
+            msg = (
+                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
+                "Create ComponentTool from a non-pipeline component instead."
+            )
+            raise ValueError(msg)
+
+        # Create the tool schema from the component's run method parameters
+        tool_schema = self._create_tool_parameters_schema(component)
+
+        def component_invoker(**kwargs):
+            """
+            Invokes the component using keyword arguments provided by the LLM's function/tool call response.
+
+            :param kwargs: The keyword arguments to invoke the component with.
+            :returns: The result of the component invocation.
+            """
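+            # For example, a tool-call argument {"name": "Alice", "age": 30} for a
+            # dataclass-typed parameter is turned into a dataclass instance below via
+            # pydantic's TypeAdapter, while types exposing `from_dict` (such as Document)
+            # are rebuilt with `from_dict` instead.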
+ """ + converted_kwargs = {} + input_sockets = component.__haystack_input__._sockets_dict + for param_name, param_value in kwargs.items(): + param_type = input_sockets[param_name].type + + # Check if the type (or list element type) has from_dict + target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type + if hasattr(target_type, "from_dict"): + if isinstance(param_value, list): + param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)] + elif isinstance(param_value, dict): + param_value = target_type.from_dict(param_value) + else: + # Let TypeAdapter handle both single values and lists + type_adapter = TypeAdapter(param_type) + param_value = type_adapter.validate_python(param_value) + + converted_kwargs[param_name] = param_value + logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}") + return component.run(**converted_kwargs) + + # Generate a name for the tool if not provided + if not name: + class_name = component.__class__.__name__ + # Convert camelCase/PascalCase to snake_case + name = "".join( + [ + "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower() + for i, c in enumerate(class_name) + ] + ).lstrip("_") + + # Generate a description for the tool if not provided and truncate to 512 characters + # as most LLMs have a limit for the description length + description = (description or component.__doc__ or name)[:512] + + # Create the Tool instance with the component invoker as the function to be called and the schema + super().__init__(name, description, tool_schema, component_invoker) + self._component = component + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the ComponentTool to a dictionary. + """ + # we do not serialize the function in this case: it can be recreated from the component at deserialization time + serialized = {"name": self.name, "description": self.description, "parameters": self.parameters} + serialized["component"] = component_to_dict(obj=self._component, name=self.name) + return {"type": generate_qualified_class_name(type(self)), "data": serialized} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Tool": + """ + Deserializes the ComponentTool from a dictionary. + """ + inner_data = data["data"] + component_class = import_class_by_name(inner_data["component"]["type"]) + component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"]) + return cls(component=component, name=inner_data["name"], description=inner_data["description"]) + + def _create_tool_parameters_schema(self, component: Component) -> Dict[str, Any]: + """ + Creates an OpenAI tools schema from a component's run method parameters. + + :param component: The component to create the schema from. + :raises SchemaGenerationError: If schema generation fails + :returns: OpenAI tools schema for the component's run method parameters. + """ + properties = {} + required = [] + + param_descriptions = self._get_param_descriptions(component.run) + + for input_name, socket in component.__haystack_input__._sockets_dict.items(): # type: ignore[attr-defined] + input_type = socket.type + description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.") + + try: + property_schema = self._create_property_schema(input_type, description) + except Exception as e: + raise SchemaGenerationError( + f"Error processing input '{input_name}': {e}. 
" + f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, " + f"and lists of these types as input types for component's run method." + ) from e + + properties[input_name] = property_schema + + # Use socket.is_mandatory to check if the input is required + if socket.is_mandatory: + required.append(input_name) + + parameters_schema = {"type": "object", "properties": properties} + + if required: + parameters_schema["required"] = required + + return parameters_schema + + @staticmethod + def _get_param_descriptions(method: Callable) -> Dict[str, str]: + """ + Extracts parameter descriptions from the method's docstring using docstring_parser. + + :param method: The method to extract parameter descriptions from. + :returns: A dictionary mapping parameter names to their descriptions. + """ + docstring = getdoc(method) + if not docstring: + return {} + + docstring_parser_import.check() + parsed_doc = parse(docstring) + param_descriptions = {} + for param in parsed_doc.params: + if not param.description: + logger.warning( + "Missing description for parameter '%s'. Please add a description in the component's " + "run() method docstring using the format ':param %%s: '. " + "This description helps the LLM understand how to use this parameter." % param.arg_name + ) + param_descriptions[param.arg_name] = param.description.strip() if param.description else "" + return param_descriptions + + @staticmethod + def _is_nullable_type(python_type: Any) -> bool: + """ + Checks if the type is a Union with NoneType (i.e., Optional). + + :param python_type: The Python type to check. + :returns: True if the type is a Union with NoneType, False otherwise. + """ + origin = get_origin(python_type) + if origin is Union: + return type(None) in get_args(python_type) + return False + + def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a list type. + + :param item_type: The type of items in the list. + :param description: The description of the list. + :returns: A dictionary representing the list schema. + """ + items_schema = self._create_property_schema(item_type, "") + items_schema.pop("description", None) + return {"type": "array", "description": description, "items": items_schema} + + def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a dataclass. + + :param python_type: The dataclass type. + :param description: The description of the dataclass. + :returns: A dictionary representing the dataclass schema. + """ + schema = {"type": "object", "description": description, "properties": {}} + cls = python_type if isinstance(python_type, type) else python_type.__class__ + for field in fields(cls): + field_description = f"Field '{field.name}' of '{cls.__name__}'." + if isinstance(schema["properties"], dict): + schema["properties"][field.name] = self._create_property_schema(field.type, field_description) + return schema + + @staticmethod + def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a basic Python type. + + :param python_type: The Python type. + :param description: The description of the type. + :returns: A dictionary representing the basic type schema. 
+ """ + type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"} + return {"type": type_mapping.get(python_type, "string"), "description": description} + + def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: + """ + Creates a property schema for a given Python type, recursively if necessary. + + :param python_type: The Python type to create a property schema for. + :param description: The description of the property. + :param default: The default value of the property. + :returns: A dictionary representing the property schema. + :raises SchemaGenerationError: If schema generation fails, e.g., for unsupported types like Pydantic v2 models + """ + nullable = self._is_nullable_type(python_type) + if nullable: + non_none_types = [t for t in get_args(python_type) if t is not type(None)] + python_type = non_none_types[0] if non_none_types else str + + origin = get_origin(python_type) + if origin is list: + schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description) + elif is_dataclass(python_type): + schema = self._create_dataclass_schema(python_type, description) + elif hasattr(python_type, "model_validate"): + raise SchemaGenerationError( + f"Pydantic models (e.g. {python_type.__name__}) are not supported as input types for " + f"component's run method." + ) + else: + schema = self._create_basic_type_schema(python_type, description) + + if default is not None: + schema["default"] = default + + return schema diff --git a/pyproject.toml b/pyproject.toml index 73031b8130..258e4e2710 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,9 @@ extra-dependencies = [ # Structured logging "structlog", + # ComponentTool + "docstring-parser", + # Test "pytest", "pytest-bdd", diff --git a/releasenotes/notes/add-component-tool-ffe9f9911ea055a6.yaml b/releasenotes/notes/add-component-tool-ffe9f9911ea055a6.yaml new file mode 100644 index 0000000000..c9db99438a --- /dev/null +++ b/releasenotes/notes/add-component-tool-ffe9f9911ea055a6.yaml @@ -0,0 +1,38 @@ +--- +highlights: | + Introduced ComponentTool, a powerful addition to the Haystack tooling architecture that enables any Haystack component to be used as a tool by LLMs. + ComponentTool bridges the gap between Haystack's component ecosystem and LLM tool/function calling capabilities, allowing LLMs to + directly interact with components like web search, document processing, or any custom user component. ComponentTool handles + all the complexity of schema generation and type conversion, making it easy to expose component functionality to LLMs. + +features: + - | + Introduced the ComponentTool, a new tool that wraps Haystack components allowing them to be utilized as tools for LLMs (various ChatGenerators). + This ComponentTool supports automatic tool schema generation, input type conversion, and offering support for components with run methods that have input types: + - Basic types (str, int, float, bool, dict) + - Dataclasses (both simple and nested structures) + - Lists of basic types (e.g., List[str]) + - Lists of dataclasses (e.g., List[Document]) + - Parameters with mixed types (e.g., List[Document], str etc.) 
+
+    Example usage:
+    ```python
+    from haystack.components.websearch import SerperDevWebSearch
+    from haystack.tools import ComponentTool
+    from haystack.utils import Secret
+
+    # Create a SerperDev search component
+    search = SerperDevWebSearch(
+        api_key=Secret.from_token("your-api-key"),
+        top_k=3
+    )
+
+    # Create a tool from the component
+    tool = ComponentTool(
+        component=search,
+        name="web_search",  # Optional: defaults to "serper_dev_web_search"
+        description="Search the web for current information"  # Optional: defaults to component docstring
+    )
+
+    # You can now use the tool in a pipeline; see the docs for more examples
+    ```
diff --git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py
new file mode 100644
index 0000000000..38f4e20464
--- /dev/null
+++ b/test/tools/test_component_tool.py
@@ -0,0 +1,569 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Dict, List
+
+import pytest
+
+from haystack import Pipeline, component
+from haystack.components.generators.chat import OpenAIChatGenerator
+from haystack.components.tools.tool_invoker import ToolInvoker
+from haystack.components.websearch.serper_dev import SerperDevWebSearch
+from haystack.dataclasses import ChatMessage, ChatRole, Document
+from haystack.tools import ComponentTool
+from haystack.utils.auth import Secret
+
+
+### Component and Model Definitions
+
+
+@component
+class SimpleComponent:
+    """A simple component that generates text."""
+
+    @component.output_types(reply=str)
+    def run(self, text: str) -> Dict[str, str]:
+        """
+        A simple component that generates text.
+
+        :param text: user's name
+        :return: A dictionary with the generated text.
+        """
+        return {"reply": f"Hello, {text}!"}
+
+
+@dataclass
+class User:
+    """A simple user dataclass."""
+
+    name: str = "Anonymous"
+    age: int = 0
+
+
+@component
+class UserGreeter:
+    """A simple component that processes a User."""
+
+    @component.output_types(message=str)
+    def run(self, user: User) -> Dict[str, str]:
+        """
+        A simple component that processes a User.
+
+        :param user: The User object to process.
+        :return: A dictionary with a message about the user.
+        """
+        return {"message": f"User {user.name} is {user.age} years old"}
+
+
+@component
+class ListProcessor:
+    """A component that processes a list of strings."""
+
+    @component.output_types(concatenated=str)
+    def run(self, texts: List[str]) -> Dict[str, str]:
+        """
+        Concatenates a list of strings into a single string.
+
+        :param texts: The list of strings to concatenate.
+        :return: A dictionary with the concatenated string.
+        """
+        return {"concatenated": " ".join(texts)}
+
+
+@dataclass
+class Address:
+    """A dataclass representing a physical address."""
+
+    street: str
+    city: str
+
+
+@dataclass
+class Person:
+    """A person with an address."""
+
+    name: str
+    address: Address
+
+
+@component
+class PersonProcessor:
+    """A component that processes a Person with nested Address."""
+
+    @component.output_types(info=str)
+    def run(self, person: Person) -> Dict[str, str]:
+        """
+        Creates information about the person.
+
+        :param person: The Person to process.
+        :return: A dictionary with the person's information.
+ """ + return {"info": f"{person.name} lives at {person.address.street}, {person.address.city}."} + + +@component +class DocumentProcessor: + """A component that processes a list of Documents.""" + + @component.output_types(concatenated=str) + def run(self, documents: List[Document], top_k: int = 5) -> Dict[str, str]: + """ + Concatenates the content of multiple documents with newlines. + + :param documents: List of Documents whose content will be concatenated + :param top_k: The number of top documents to concatenate + :returns: Dictionary containing the concatenated document contents + """ + return {"concatenated": "\n".join(doc.content for doc in documents[:top_k])} + + +## Unit tests +class TestToolComponent: + def test_from_component_basic(self): + component = SimpleComponent() + + tool = ComponentTool(component=component) + + assert tool.name == "simple_component" + assert tool.description == "A simple component that generates text." + assert tool.parameters == { + "type": "object", + "properties": {"text": {"type": "string", "description": "user's name"}}, + "required": ["text"], + } + + # Test tool invocation + result = tool.invoke(text="world") + assert isinstance(result, dict) + assert "reply" in result + assert result["reply"] == "Hello, world!" + + def test_from_component_with_dataclass(self): + component = UserGreeter() + + tool = ComponentTool(component=component) + assert tool.parameters == { + "type": "object", + "properties": { + "user": { + "type": "object", + "description": "The User object to process.", + "properties": { + "name": {"type": "string", "description": "Field 'name' of 'User'."}, + "age": {"type": "integer", "description": "Field 'age' of 'User'."}, + }, + } + }, + "required": ["user"], + } + + assert tool.name == "user_greeter" + assert tool.description == "A simple component that processes a User." 
+
+        # Test tool invocation
+        result = tool.invoke(user={"name": "Alice", "age": 30})
+        assert isinstance(result, dict)
+        assert "message" in result
+        assert result["message"] == "User Alice is 30 years old"
+
+    def test_from_component_with_list_input(self):
+        component = ListProcessor()
+
+        tool = ComponentTool(
+            component=component, name="list_processing_tool", description="A tool that concatenates strings"
+        )
+
+        assert tool.parameters == {
+            "type": "object",
+            "properties": {
+                "texts": {
+                    "type": "array",
+                    "description": "The list of strings to concatenate.",
+                    "items": {"type": "string"},
+                }
+            },
+            "required": ["texts"],
+        }
+
+        # Test tool invocation
+        result = tool.invoke(texts=["hello", "world"])
+        assert isinstance(result, dict)
+        assert "concatenated" in result
+        assert result["concatenated"] == "hello world"
+
+    def test_from_component_with_nested_dataclass(self):
+        component = PersonProcessor()
+
+        tool = ComponentTool(component=component, name="person_tool", description="A tool that processes people")
+
+        assert tool.parameters == {
+            "type": "object",
+            "properties": {
+                "person": {
+                    "type": "object",
+                    "description": "The Person to process.",
+                    "properties": {
+                        "name": {"type": "string", "description": "Field 'name' of 'Person'."},
+                        "address": {
+                            "type": "object",
+                            "description": "Field 'address' of 'Person'.",
+                            "properties": {
+                                "street": {"type": "string", "description": "Field 'street' of 'Address'."},
+                                "city": {"type": "string", "description": "Field 'city' of 'Address'."},
+                            },
+                        },
+                    },
+                }
+            },
+            "required": ["person"],
+        }
+
+        # Test tool invocation
+        result = tool.invoke(person={"name": "Diana", "address": {"street": "123 Elm Street", "city": "Metropolis"}})
+        assert isinstance(result, dict)
+        assert "info" in result
+        assert result["info"] == "Diana lives at 123 Elm Street, Metropolis."
+
+    def test_from_component_with_document_list(self):
+        component = DocumentProcessor()
+
+        tool = ComponentTool(
+            component=component, name="document_processor", description="A tool that concatenates document contents"
+        )
+
+        assert tool.parameters == {
+            "type": "object",
+            "properties": {
+                "documents": {
+                    "type": "array",
+                    "description": "List of Documents whose content will be concatenated",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "id": {"type": "string", "description": "Field 'id' of 'Document'."},
+                            "content": {"type": "string", "description": "Field 'content' of 'Document'."},
+                            "dataframe": {"type": "string", "description": "Field 'dataframe' of 'Document'."},
+                            "blob": {
+                                "type": "object",
+                                "description": "Field 'blob' of 'Document'.",
+                                "properties": {
+                                    "data": {"type": "string", "description": "Field 'data' of 'ByteStream'."},
+                                    "meta": {"type": "string", "description": "Field 'meta' of 'ByteStream'."},
+                                    "mime_type": {
+                                        "type": "string",
+                                        "description": "Field 'mime_type' of 'ByteStream'.",
+                                    },
+                                },
+                            },
+                            "meta": {"type": "string", "description": "Field 'meta' of 'Document'."},
+                            "score": {"type": "number", "description": "Field 'score' of 'Document'."},
+                            "embedding": {
+                                "type": "array",
+                                "description": "Field 'embedding' of 'Document'.",
+                                "items": {"type": "number"},
+                            },
+                            "sparse_embedding": {
+                                "type": "object",
+                                "description": "Field 'sparse_embedding' of 'Document'.",
+                                "properties": {
+                                    "indices": {
+                                        "type": "array",
+                                        "description": "Field 'indices' of 'SparseEmbedding'.",
+                                        "items": {"type": "integer"},
+                                    },
+                                    "values": {
+                                        "type": "array",
+                                        "description": "Field 'values' of 'SparseEmbedding'.",
+                                        "items": {"type": "number"},
+                                    },
+                                },
+                            },
+                        },
+                    },
+                },
+                "top_k": {"description": "The number of top documents to concatenate", "type": "integer"},
+            },
+            "required": ["documents"],
+        }
+
+        # Test tool invocation
+        result = tool.invoke(documents=[{"content": "First document"}, {"content": "Second document"}])
+        assert isinstance(result, dict)
+        assert "concatenated" in result
+        assert result["concatenated"] == "First document\nSecond document"
+
+    def test_from_component_with_non_component(self):
+        class NotAComponent:
+            def foo(self, text: str):
+                return {"reply": f"Hello, {text}!"}
+
+        not_a_component = NotAComponent()
+
+        with pytest.raises(ValueError):
+            ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail")
+
+
+## Integration tests
+class TestToolComponentInPipelineWithOpenAI:
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    def test_component_tool_in_pipeline(self):
+        # Create component and convert it to tool
+        component = SimpleComponent()
+        tool = ComponentTool(
+            component=component, name="hello_tool", description="A tool that generates a greeting message for the user"
+        )
+
+        # Create pipeline with OpenAIChatGenerator and ToolInvoker
+        pipeline = Pipeline()
+        pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
+        pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
+
+        # Connect components
+        pipeline.connect("llm.replies", "tool_invoker.messages")
+
+        message = ChatMessage.from_user(text="Vladimir")
+
+        # Run pipeline
+        result = pipeline.run({"llm": {"messages": [message]}})
+
+        # Check results
+        tool_messages = result["tool_invoker"]["tool_messages"]
+        assert len(tool_messages) == 1
+
+        tool_message = tool_messages[0]
+        assert tool_message.is_from(ChatRole.TOOL)
+        assert "Vladimir" in tool_message.tool_call_result.result
+        assert not tool_message.tool_call_result.error
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    def test_user_greeter_in_pipeline(self):
+        component = UserGreeter()
+        tool = ComponentTool(
+            component=component, name="user_greeter", description="A tool that greets users with their name and age"
+        )
+
+        pipeline = Pipeline()
+        pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
+        pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
+        pipeline.connect("llm.replies", "tool_invoker.messages")
+
+        message = ChatMessage.from_user(text="I am Alice and I'm 30 years old")
+
+        result = pipeline.run({"llm": {"messages": [message]}})
+        tool_messages = result["tool_invoker"]["tool_messages"]
+        assert len(tool_messages) == 1
+
+        tool_message = tool_messages[0]
+        assert tool_message.is_from(ChatRole.TOOL)
+        assert tool_message.tool_call_result.result == str({"message": "User Alice is 30 years old"})
+        assert not tool_message.tool_call_result.error
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    def test_list_processor_in_pipeline(self):
+        component = ListProcessor()
+        tool = ComponentTool(
+            component=component, name="list_processor", description="A tool that concatenates a list of strings"
+        )
+
+        pipeline = Pipeline()
+        pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
+        pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
+        pipeline.connect("llm.replies", "tool_invoker.messages")
+
+        message = ChatMessage.from_user(text="Can you join these words: hello, beautiful, world")
+
+        result = pipeline.run({"llm": {"messages": [message]}})
+        tool_messages = result["tool_invoker"]["tool_messages"]
+        assert len(tool_messages) == 1
+
+        tool_message = tool_messages[0]
+        assert tool_message.is_from(ChatRole.TOOL)
+        assert tool_message.tool_call_result.result == str({"concatenated": "hello beautiful world"})
+        assert not tool_message.tool_call_result.error
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    def test_person_processor_in_pipeline(self):
+        component = PersonProcessor()
+        tool = ComponentTool(
+            component=component,
+            name="person_processor",
+            description="A tool that processes information about a person and their address",
+        )
+
+        pipeline = Pipeline()
+        pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
+        pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
+        pipeline.connect("llm.replies", "tool_invoker.messages")
+
+        message = ChatMessage.from_user(text="Diana lives at 123 Elm Street in Metropolis")
+
+        result = pipeline.run({"llm": {"messages": [message]}})
+        tool_messages = result["tool_invoker"]["tool_messages"]
+        assert len(tool_messages) == 1
+
+        tool_message = tool_messages[0]
+        assert tool_message.is_from(ChatRole.TOOL)
+        assert "Diana" in tool_message.tool_call_result.result and "Metropolis" in tool_message.tool_call_result.result
+        assert not tool_message.tool_call_result.error
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    def test_document_processor_in_pipeline(self):
+        component = DocumentProcessor()
+        tool = ComponentTool(
+            component=component,
+            name="document_processor",
+            description="A tool that concatenates the content of multiple documents",
+        )
+
+        pipeline = Pipeline()
+        pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
+        pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool], convert_result_to_json_string=True))
+        pipeline.connect("llm.replies", "tool_invoker.messages")
+
+        message = ChatMessage.from_user(
+            text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only the content field of the document. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields."
+        )
+
+        result = pipeline.run({"llm": {"messages": [message]}})
+
+        tool_messages = result["tool_invoker"]["tool_messages"]
+        assert len(tool_messages) == 1
+
+        tool_message = tool_messages[0]
+        assert tool_message.is_from(ChatRole.TOOL)
+        result = json.loads(tool_message.tool_call_result.result)
+        assert "concatenated" in result
+        assert "Hello world" in result["concatenated"]
+        assert "Goodbye world" in result["concatenated"]
+        assert not tool_message.tool_call_result.error
+
+    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
+    @pytest.mark.integration
+    def test_lost_in_middle_ranker_in_pipeline(self):
+        from haystack.components.rankers import LostInTheMiddleRanker
+
+        component = LostInTheMiddleRanker()
+        tool = ComponentTool(
+            component=component,
+            name="lost_in_middle_ranker",
+            description="A tool that ranks documents using the Lost in the Middle algorithm and returns top k results",
+        )
+
+        pipeline = Pipeline()
+        pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
+        pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
+        pipeline.connect("llm.replies", "tool_invoker.messages")
+
+        message = ChatMessage.from_user(
+            text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them with top_k=2. Set only the content field of the document. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields."
+ ) + + result = pipeline.run({"llm": {"messages": [message]}}) + + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.skipif(not os.environ.get("SERPERDEV_API_KEY"), reason="SERPERDEV_API_KEY not set") + @pytest.mark.integration + def test_serper_dev_web_search_in_pipeline(self): + component = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3) + tool = ComponentTool( + component=component, name="web_search", description="Search the web for current information on any topic" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + result = pipeline.run( + { + "llm": { + "messages": [ + ChatMessage.from_user(text="Use the web search tool to find information about Nikola Tesla") + ] + } + } + ) + + assert len(result["tool_invoker"]["tool_messages"]) == 1 + tool_message = result["tool_invoker"]["tool_messages"][0] + assert tool_message.is_from(ChatRole.TOOL) + assert "Nikola Tesla" in tool_message.tool_call_result.result + assert not tool_message.tool_call_result.error + + def test_serde_in_pipeline(self, monkeypatch): + monkeypatch.setenv("SERPERDEV_API_KEY", "test-key") + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + + # Create the search component and tool + search = SerperDevWebSearch(top_k=3) + tool = ComponentTool(component=search, name="web_search", description="Search the web for current information") + + # Create and configure the pipeline + pipeline = Pipeline() + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.connect("tool_invoker.tool_messages", "llm.messages") + + # Serialize to dict and verify structure + pipeline_dict = pipeline.to_dict() + assert ( + pipeline_dict["components"]["tool_invoker"]["type"] == "haystack.components.tools.tool_invoker.ToolInvoker" + ) + assert len(pipeline_dict["components"]["tool_invoker"]["init_parameters"]["tools"]) == 1 + + tool_dict = pipeline_dict["components"]["tool_invoker"]["init_parameters"]["tools"][0] + assert tool_dict["type"] == "haystack.tools.component_tool.ComponentTool" + assert tool_dict["data"]["name"] == "web_search" + assert tool_dict["data"]["component"]["type"] == "haystack.components.websearch.serper_dev.SerperDevWebSearch" + assert tool_dict["data"]["component"]["init_parameters"]["top_k"] == 3 + assert tool_dict["data"]["component"]["init_parameters"]["api_key"]["type"] == "env_var" + + # Test round-trip serialization + pipeline_yaml = pipeline.dumps() + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + def test_component_tool_serde(self): + component = SimpleComponent() + + tool = ComponentTool(component=component, name="simple_tool", description="A simple tool") + + # Test serialization + tool_dict = tool.to_dict() + assert tool_dict["type"] == "haystack.tools.component_tool.ComponentTool" + assert tool_dict["data"]["name"] == "simple_tool" + assert tool_dict["data"]["description"] == "A simple tool" + assert "component" in tool_dict["data"] + + # Test deserialization + new_tool = ComponentTool.from_dict(tool_dict) + assert 
new_tool.name == tool.name + assert new_tool.description == tool.description + assert new_tool.parameters == tool.parameters + assert isinstance(new_tool._component, SimpleComponent) + + def test_pipeline_component_fails(self): + component = SimpleComponent() + + # Create a pipeline and add the component to it + pipeline = Pipeline() + pipeline.add_component("simple", component) + + # Try to create a tool from the component and it should fail because the component has been added to a pipeline and + # thus can't be used as tool + with pytest.raises(ValueError, match="Component has been added to a pipeline"): + ComponentTool(component=component)