-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Tool
dataclass - unified abstraction to represent tools
#8652
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
4b941b0
draft
anakin87 9ec7de7
del HF token in tests
anakin87 4865ab0
Merge branch 'fix-hf-token-test' into new-chatmessage
anakin87 7b6e9d2
adaptations
anakin87 c462ddc
progress
anakin87 94a103a
Merge branch 'main' into new-chatmessage
anakin87 873ae4f
fix type
anakin87 fe6c4c8
import sorting
anakin87 1a5b46c
more control on deserialization
anakin87 e3f4c89
release note
anakin87 2370c2f
Merge branch 'main' into new-chatmessage
anakin87 180d0f3
improvements
anakin87 328cebd
support name field
anakin87 b88daae
fix chatpromptbuilder test
anakin87 a83b10c
port Tool from experimental
anakin87 b3c9381
release note
anakin87 65e73fd
Merge branch 'new-chatmessage' into tool-dataclass
anakin87 8a87019
Merge branch 'main' into tool-dataclass
anakin87 c50a119
docs upd
dfokina b35a568
Update tool.py
dfokina File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import inspect | ||
from dataclasses import asdict, dataclass | ||
from typing import Any, Callable, Dict, Optional | ||
|
||
from pydantic import create_model | ||
|
||
from haystack.lazy_imports import LazyImport | ||
from haystack.utils import deserialize_callable, serialize_callable | ||
|
||
with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: | ||
from jsonschema import Draft202012Validator | ||
from jsonschema.exceptions import SchemaError | ||
|
||
|
||
class ToolInvocationError(Exception): | ||
""" | ||
Exception raised when a Tool invocation fails. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class SchemaGenerationError(Exception): | ||
""" | ||
Exception raised when automatic schema generation fails. | ||
""" | ||
|
||
pass | ||
|
||
|
||
@dataclass | ||
class Tool: | ||
""" | ||
Data class representing a Tool that Language Models can prepare a call for. | ||
|
||
Accurate definitions of the textual attributes such as `name` and `description` | ||
are important for the Language Model to correctly prepare the call. | ||
|
||
:param name: | ||
Name of the Tool. | ||
:param description: | ||
Description of the Tool. | ||
:param parameters: | ||
A JSON schema defining the parameters expected by the Tool. | ||
:param function: | ||
The function that will be invoked when the Tool is called. | ||
""" | ||
|
||
name: str | ||
description: str | ||
parameters: Dict[str, Any] | ||
function: Callable | ||
|
||
def __post_init__(self): | ||
jsonschema_import.check() | ||
# Check that the parameters define a valid JSON schema | ||
try: | ||
Draft202012Validator.check_schema(self.parameters) | ||
except SchemaError as e: | ||
raise ValueError("The provided parameters do not define a valid JSON schema") from e | ||
|
||
@property | ||
def tool_spec(self) -> Dict[str, Any]: | ||
""" | ||
Return the Tool specification to be used by the Language Model. | ||
""" | ||
return {"name": self.name, "description": self.description, "parameters": self.parameters} | ||
|
||
def invoke(self, **kwargs) -> Any: | ||
""" | ||
Invoke the Tool with the provided keyword arguments. | ||
""" | ||
|
||
try: | ||
result = self.function(**kwargs) | ||
except Exception as e: | ||
raise ToolInvocationError(f"Failed to invoke Tool `{self.name}` with parameters {kwargs}") from e | ||
return result | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
Serializes the Tool to a dictionary. | ||
|
||
:returns: | ||
Dictionary with serialized data. | ||
""" | ||
|
||
serialized = asdict(self) | ||
serialized["function"] = serialize_callable(self.function) | ||
return serialized | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> "Tool": | ||
""" | ||
Deserializes the Tool from a dictionary. | ||
|
||
:param data: | ||
Dictionary to deserialize from. | ||
:returns: | ||
Deserialized Tool. | ||
""" | ||
data["function"] = deserialize_callable(data["function"]) | ||
return cls(**data) | ||
|
||
@classmethod | ||
def from_function(cls, function: Callable, name: Optional[str] = None, description: Optional[str] = None) -> "Tool": | ||
""" | ||
Create a Tool instance from a function. | ||
|
||
### Usage example | ||
|
||
```python | ||
from typing import Annotated, Literal | ||
from haystack.dataclasses import Tool | ||
|
||
def get_weather( | ||
city: Annotated[str, "the city for which to get the weather"] = "Munich", | ||
unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius"): | ||
'''A simple function to get the current weather for a location.''' | ||
return f"Weather report for {city}: 20 {unit}, sunny" | ||
|
||
tool = Tool.from_function(get_weather) | ||
|
||
print(tool) | ||
>>> Tool(name='get_weather', description='A simple function to get the current weather for a location.', | ||
>>> parameters={ | ||
>>> 'type': 'object', | ||
>>> 'properties': { | ||
>>> 'city': {'type': 'string', 'description': 'the city for which to get the weather', 'default': 'Munich'}, | ||
>>> 'unit': { | ||
>>> 'type': 'string', | ||
>>> 'enum': ['Celsius', 'Fahrenheit'], | ||
>>> 'description': 'the unit for the temperature', | ||
>>> 'default': 'Celsius', | ||
>>> }, | ||
>>> } | ||
>>> }, | ||
>>> function=<function get_weather at 0x7f7b3a8a9b80>) | ||
``` | ||
|
||
:param function: | ||
The function to be converted into a Tool. | ||
The function must include type hints for all parameters. | ||
If a parameter is annotated using `typing.Annotated`, its metadata will be used as parameter description. | ||
:param name: | ||
The name of the Tool. If not provided, the name of the function will be used. | ||
:param description: | ||
The description of the Tool. If not provided, the docstring of the function will be used. | ||
To intentionally leave the description empty, pass an empty string. | ||
|
||
:returns: | ||
The Tool created from the function. | ||
|
||
:raises ValueError: | ||
If any parameter of the function lacks a type hint. | ||
:raises SchemaGenerationError: | ||
If there is an error generating the JSON schema for the Tool. | ||
""" | ||
|
||
tool_description = description if description is not None else (function.__doc__ or "") | ||
|
||
signature = inspect.signature(function) | ||
|
||
# collect fields (types and defaults) and descriptions from function parameters | ||
fields: Dict[str, Any] = {} | ||
descriptions = {} | ||
|
||
for param_name, param in signature.parameters.items(): | ||
if param.annotation is param.empty: | ||
raise ValueError(f"Function '{function.__name__}': parameter '{param_name}' does not have a type hint.") | ||
|
||
# if the parameter has not a default value, Pydantic requires an Ellipsis (...) | ||
# to explicitly indicate that the parameter is required | ||
default = param.default if param.default is not param.empty else ... | ||
fields[param_name] = (param.annotation, default) | ||
|
||
if hasattr(param.annotation, "__metadata__"): | ||
descriptions[param_name] = param.annotation.__metadata__[0] | ||
|
||
# create Pydantic model and generate JSON schema | ||
try: | ||
model = create_model(function.__name__, **fields) | ||
schema = model.model_json_schema() | ||
except Exception as e: | ||
raise SchemaGenerationError(f"Failed to create JSON schema for function '{function.__name__}'") from e | ||
|
||
# we don't want to include title keywords in the schema, as they contain redundant information | ||
# there is no programmatic way to prevent Pydantic from adding them, so we remove them later | ||
# see https://github.com/pydantic/pydantic/discussions/8504 | ||
_remove_title_from_schema(schema) | ||
|
||
# add parameters descriptions to the schema | ||
for param_name, param_description in descriptions.items(): | ||
if param_name in schema["properties"]: | ||
schema["properties"][param_name]["description"] = param_description | ||
|
||
return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) | ||
|
||
|
||
def _remove_title_from_schema(schema: Dict[str, Any]): | ||
""" | ||
Remove the 'title' keyword from JSON schema and contained property schemas. | ||
|
||
:param schema: | ||
The JSON schema to remove the 'title' keyword from. | ||
""" | ||
schema.pop("title", None) | ||
|
||
for property_schema in schema["properties"].values(): | ||
for key in list(property_schema.keys()): | ||
if key == "title": | ||
del property_schema[key] | ||
|
||
|
||
def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): | ||
""" | ||
Deserialize Tools in a dictionary inplace. | ||
|
||
:param data: | ||
The dictionary with the serialized data. | ||
:param key: | ||
The key in the dictionary where the Tools are stored. | ||
""" | ||
if key in data: | ||
serialized_tools = data[key] | ||
|
||
if serialized_tools is None: | ||
return | ||
|
||
if not isinstance(serialized_tools, list): | ||
raise TypeError(f"The value of '{key}' is not a list") | ||
|
||
deserialized_tools = [] | ||
for tool in serialized_tools: | ||
if not isinstance(tool, dict): | ||
raise TypeError(f"Serialized tool '{tool}' is not a dictionary") | ||
deserialized_tools.append(Tool.from_dict(tool)) | ||
|
||
data[key] = deserialized_tools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
--- | ||
highlights: > | ||
We are introducing the `Tool` dataclass: a simple and unified abstraction to represent tools throughout the framework. | ||
By building on this abstraction, we will enable support for tools in Chat Generators, | ||
providing a consistent experience across models. | ||
features: | ||
- | | ||
Added a new `Tool` dataclass to represent a tool for which Language Models can prepare calls. |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Knowing that Tool is central piece of the Agents push, should be we make this dependency default? It's 80Kb binary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not totally sure. Let's involve also @julian-risch in the decision.
jsonschema
: ci: Skip collection oftest_json_schema.py
to fix CI failures #7353There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to be not occurring on 3.9 and after, but ok, let's get this integrated and then we can experiment with including it as a default dependency. Or not.