Skip to content

Commit

Permalink
refactor: move Tool to a separate package; refactor serde (#8690)
Browse files Browse the repository at this point in the history
* move tool to separate package; refactor serde

* release note

* rm unused import
  • Loading branch information
anakin87 authored Jan 9, 2025
1 parent 28ad78c commit 3f15f38
Show file tree
Hide file tree
Showing 14 changed files with 127 additions and 56 deletions.
2 changes: 1 addition & 1 deletion docs/pydoc/config/data_classess_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/dataclasses]
modules:
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding", "tool"]
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding",]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand Down
27 changes: 27 additions & 0 deletions docs/pydoc/config/tool_components_api.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/tools]
modules: ["tool_invoker"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
expression:
documented_only: true
do_not_filter_modules: false
skip_empty_modules: true
- type: smart
- type: crossref
renderer:
type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer
excerpt: Components related to Tool Calling.
category_slug: haystack-api
title: Tool Components
slug: tool-components-api
order: 152
markdown:
descriptive_class_title: false
classdef_code_block: false
descriptive_module_title: true
add_method_class_prefix: true
add_member_class_prefix: false
filename: tool_components_api.md
9 changes: 5 additions & 4 deletions docs/pydoc/config/tools_api.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/tools]
modules: ["tool_invoker"]
search_path: [../../../haystack/tools]
modules:
["tool"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand All @@ -13,11 +14,11 @@ processors:
- type: crossref
renderer:
type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer
excerpt: Components related to Tool Calling.
excerpt: Unified abstractions to represent tools across the framework.
category_slug: haystack-api
title: Tools
slug: tools-api
order: 152
order: 151
markdown:
descriptive_class_title: false
classdef_code_block: false
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.lazy_imports import LazyImport
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/tools/tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.tools.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace

logger = logging.getLogger(__name__)

Expand Down
2 changes: 0 additions & 2 deletions haystack/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from haystack.dataclasses.document import Document
from haystack.dataclasses.sparse_embedding import SparseEmbedding
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.dataclasses.tool import Tool

__all__ = [
"Document",
Expand All @@ -23,5 +22,4 @@
"TextContent",
"StreamingChunk",
"SparseEmbedding",
"Tool",
]
7 changes: 7 additions & 0 deletions haystack/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace

__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace"]
20 changes: 14 additions & 6 deletions haystack/dataclasses/tool.py → haystack/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pydantic import create_model

from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
from haystack.lazy_imports import LazyImport
from haystack.utils import deserialize_callable, serialize_callable

Expand Down Expand Up @@ -89,9 +90,9 @@ def to_dict(self) -> Dict[str, Any]:
Dictionary with serialized data.
"""

serialized = asdict(self)
serialized["function"] = serialize_callable(self.function)
return serialized
data = asdict(self)
data["function"] = serialize_callable(self.function)
return {"type": generate_qualified_class_name(type(self)), "data": data}

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Tool":
Expand All @@ -103,8 +104,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "Tool":
:returns:
Deserialized Tool.
"""
data["function"] = deserialize_callable(data["function"])
return cls(**data)
init_parameters = data["data"]
init_parameters["function"] = deserialize_callable(init_parameters["function"])
return cls(**init_parameters)

@classmethod
def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool":
Expand Down Expand Up @@ -253,6 +255,12 @@ def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
for tool in serialized_tools:
if not isinstance(tool, dict):
raise TypeError(f"Serialized tool '{tool}' is not a dictionary")
deserialized_tools.append(Tool.from_dict(tool))

# different classes are allowed: Tool, ComponentTool, etc.
tool_class = import_class_by_name(tool["type"])
if not issubclass(tool_class, Tool):
raise TypeError(f"Class '{tool_class}' is not a subclass of Tool")

deserialized_tools.append(tool_class.from_dict(tool))

data[key] = deserialized_tools
5 changes: 5 additions & 0 deletions releasenotes/notes/tool-refactor-7ed98e3ee4de14c3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Move `Tool` to a new dedicated `tools` package.
Refactor `Tool` serialization and deserialization to make it more flexible and include type information.
25 changes: 16 additions & 9 deletions test/components/generators/chat/test_hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, Tool, ToolCall
from haystack.tools import Tool
from haystack.dataclasses import ChatMessage, ToolCall


@pytest.fixture
Expand Down Expand Up @@ -217,10 +218,13 @@ def test_to_dict(self, mock_check_valid_model):
assert init_params["streaming_callback"] is None
assert init_params["tools"] == [
{
"description": "description",
"function": "builtins.print",
"name": "name",
"parameters": {"x": {"type": "string"}},
"type": "haystack.tools.tool.Tool",
"data": {
"description": "description",
"function": "builtins.print",
"name": "name",
"parameters": {"x": {"type": "string"}},
},
}
]

Expand Down Expand Up @@ -276,10 +280,13 @@ def test_serde_in_pipeline(self, mock_check_valid_model):
"streaming_callback": None,
"tools": [
{
"name": "name",
"description": "description",
"parameters": {"x": {"type": "string"}},
"function": "builtins.print",
"type": "haystack.tools.tool.Tool",
"data": {
"name": "name",
"description": "description",
"parameters": {"x": {"type": "string"}},
"function": "builtins.print",
},
}
],
},
Expand Down
25 changes: 16 additions & 9 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import StreamingChunk
from haystack.utils.auth import Secret
from haystack.dataclasses import ChatMessage, Tool, ToolCall
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import Tool
from haystack.components.generators.chat.openai import OpenAIChatGenerator


Expand Down Expand Up @@ -200,10 +201,13 @@ def test_to_dict_with_parameters(self, monkeypatch):
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"tools": [
{
"description": "description",
"function": "builtins.print",
"name": "name",
"parameters": {"x": {"type": "string"}},
"type": "haystack.tools.tool.Tool",
"data": {
"description": "description",
"function": "builtins.print",
"name": "name",
"parameters": {"x": {"type": "string"}},
},
}
],
"tools_strict": True,
Expand All @@ -224,10 +228,13 @@ def test_from_dict(self, monkeypatch):
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"tools": [
{
"description": "description",
"function": "builtins.print",
"name": "name",
"parameters": {"x": {"type": "string"}},
"type": "haystack.tools.tool.Tool",
"data": {
"description": "description",
"function": "builtins.print",
"name": "name",
"parameters": {"x": {"type": "string"}},
},
}
],
"tools_strict": True,
Expand Down
19 changes: 11 additions & 8 deletions test/components/tools/test_tool_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from haystack import Pipeline

from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole
from haystack.dataclasses.tool import Tool, ToolInvocationError
from haystack.tools.tool import Tool, ToolInvocationError
from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError
from haystack.components.generators.chat.openai import OpenAIChatGenerator

Expand Down Expand Up @@ -238,14 +238,17 @@ def test_serde_in_pipeline(self, invoker, monkeypatch):
"init_parameters": {
"tools": [
{
"name": "weather_tool",
"description": "Provides weather information for a given location.",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
"required": ["location"],
"type": "haystack.tools.tool.Tool",
"data": {
"name": "weather_tool",
"description": "Provides weather information for a given location.",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
"required": ["location"],
},
"function": "tools.test_tool_invoker.weather_function",
},
"function": "tools.test_tool_invoker.weather_function",
}
],
"raise_on_failure": True,
Expand Down
36 changes: 22 additions & 14 deletions test/dataclasses/test_tool.py → test/tools/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from typing import Literal, Optional

import pytest

from haystack.dataclasses.tool import (
from haystack.tools.tool import (
SchemaGenerationError,
Tool,
ToolInvocationError,
Expand Down Expand Up @@ -78,18 +77,24 @@ def test_to_dict(self):
)

assert tool.to_dict() == {
"name": "weather",
"description": "Get weather report",
"parameters": parameters,
"function": "test_tool.get_weather_report",
"type": "haystack.tools.tool.Tool",
"data": {
"name": "weather",
"description": "Get weather report",
"parameters": parameters,
"function": "test_tool.get_weather_report",
},
}

def test_from_dict(self):
tool_dict = {
"name": "weather",
"description": "Get weather report",
"parameters": parameters,
"function": "test_tool.get_weather_report",
"type": "haystack.tools.tool.Tool",
"data": {
"name": "weather",
"description": "Get weather report",
"parameters": parameters,
"function": "test_tool.get_weather_report",
},
}

tool = Tool.from_dict(tool_dict)
Expand Down Expand Up @@ -179,14 +184,12 @@ def function_with_annotations(

def test_deserialize_tools_inplace():
tool = Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report)
serialized_tool = tool.to_dict()
print(serialized_tool)

data = {"tools": [serialized_tool.copy()]}
data = {"tools": [tool.to_dict()]}
deserialize_tools_inplace(data)
assert data["tools"] == [tool]

data = {"mytools": [serialized_tool.copy()]}
data = {"mytools": [tool.to_dict()]}
deserialize_tools_inplace(data, key="mytools")
assert data["mytools"] == [tool]

Expand All @@ -212,6 +215,11 @@ def test_deserialize_tools_inplace_failures():
with pytest.raises(TypeError):
deserialize_tools_inplace(data)

# not a subclass of Tool
data = {"tools": [{"type": "haystack.dataclasses.ChatMessage", "data": {"irrelevant": "irrelevant"}}]}
with pytest.raises(TypeError):
deserialize_tools_inplace(data)


def test_remove_title_from_schema():
complex_schema = {
Expand Down

0 comments on commit 3f15f38

Please sign in to comment.