From 81be502defa6320d1e316b43126df14e9a53f7a4 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Tue, 24 Sep 2024 16:43:44 +0200 Subject: [PATCH 001/229] ci: adding github workflow for Snowflake (#1097) * Adding github workflow for Snowflake * Fix docs configs --------- Co-authored-by: Silvano Cerza --- .github/workflows/snowflake.yml | 74 +++++++++++++++++++++++++ integrations/snowflake/pydoc/config.yml | 4 +- 2 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/snowflake.yml diff --git a/.github/workflows/snowflake.yml b/.github/workflows/snowflake.yml new file mode 100644 index 000000000..19596f312 --- /dev/null +++ b/.github/workflows/snowflake.yml @@ -0,0 +1,74 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / snowflake + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/snowflake/**" + - ".github/workflows/snowflake.yml" + +defaults: + run: + working-directory: integrations/snowflake + +concurrency: + group: snowflake-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run cov-retry + + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/integrations/snowflake/pydoc/config.yml b/integrations/snowflake/pydoc/config.yml index 7237b3816..3ebeaef51 100644 --- a/integrations/snowflake/pydoc/config.yml +++ b/integrations/snowflake/pydoc/config.yml @@ -3,7 +3,7 @@ loaders: search_path: [../src] modules: [ - "haystack_integrations.components.retrievers.snowflake.snowflake_retriever" + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever", ] ignore_when_discovered: ["__init__"] processors: @@ -27,4 +27,4 @@ renderer: descriptive_module_title: true add_method_class_prefix: true add_member_class_prefix: false - filename: _readme_snowflake.md \ No newline at end of file + filename: _readme_snowflake.md From bca32be6d7ccc71b8e0ffbf575e8370fd909f477 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 24 Sep 2024 18:35:32 
+0200 Subject: [PATCH 002/229] feat: Add chatrole tests and meta for GeminiChatGenerators (#1090) --- .../generators/google_ai/chat/gemini.py | 43 +++++++++---- .../tests/generators/chat/test_chat_gemini.py | 60 +++++++++++++++---- .../generators/google_vertex/chat/gemini.py | 41 ++++++++++--- .../google_vertex/tests/chat/test_gemini.py | 18 +++--- 4 files changed, 122 insertions(+), 40 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index e859a29fd..56c84968b 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -311,17 +311,25 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess :param response_body: The response from Google AI request. :returns: The extracted responses. """ - replies = [] - for candidate in response_body.candidates: + replies: List[ChatMessage] = [] + metadata = response_body.to_dict() + for idx, candidate in enumerate(response_body.candidates): + candidate_metadata = metadata["candidates"][idx] + candidate_metadata.pop("content", None) # we remove content from the metadata + for part in candidate.content.parts: if part.text != "": - replies.append(ChatMessage.from_assistant(part.text)) - elif part.function_call is not None: + replies.append( + ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata) + ) + elif part.function_call: + candidate_metadata["function_call"] = part.function_call replies.append( ChatMessage( content=dict(part.function_call.args.items()), role=ChatRole.ASSISTANT, name=part.function_call.name, + meta=candidate_metadata, ) ) return replies @@ -336,11 +344,26 @@ def _get_stream_response( :param streaming_callback: The handler for the streaming response. :returns: The extracted response with the content of all streaming chunks. 
""" - responses = [] + replies: List[ChatMessage] = [] for chunk in stream: - content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else "" - streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict())) - responses.append(content) + content: Union[str, Dict[str, Any]] = "" + metadata = chunk.to_dict() # we store whole chunk as metadata in streaming calls + for candidate in chunk.candidates: + for part in candidate.content.parts: + if part.text != "": + content = part.text + replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None)) + elif part.function_call is not None: + metadata["function_call"] = part.function_call + content = dict(part.function_call.args.items()) + replies.append( + ChatMessage( + content=content, + role=ChatRole.ASSISTANT, + name=part.function_call.name, + meta=metadata, + ) + ) - combined_response = "".join(responses).lstrip() - return [ChatMessage.from_assistant(content=combined_response)] + streaming_callback(StreamingChunk(content=content, meta=metadata)) + return replies diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 35ad8db14..c4372db0d 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -5,7 +5,7 @@ from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool from haystack.dataclasses import StreamingChunk -from haystack.dataclasses.chat_message import ChatMessage +from haystack.dataclasses.chat_message import ChatMessage, ChatRole from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator @@ -207,7 +207,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 get_current_weather_func = FunctionDeclaration.from_function( get_current_weather, descriptions={ - "location": "The city and state, e.g. San Francisco, CA", + "location": "The city, e.g. San Francisco", "unit": "The temperature unit of measurement, e.g. 
celsius or fahrenheit", }, ) @@ -215,14 +215,27 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool]) messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] - res = gemini_chat.run(messages=messages) - assert len(res["replies"]) > 0 + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) - weather = get_current_weather(**res["replies"][0].content) - messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + # check the first response is a function call + chat_message = response["replies"][0] + assert "function_call" in chat_message.meta + assert chat_message.content == {"location": "Berlin", "unit": "celsius"} - res = gemini_chat.run(messages=messages) - assert len(res["replies"]) > 0 + weather = get_current_weather(**chat_message.content) + messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + + # check the second response is not a function call + chat_message = response["replies"][0] + assert "function_call" not in chat_message.meta + assert isinstance(chat_message.content, str) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") @@ -239,7 +252,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 get_current_weather_func = FunctionDeclaration.from_function( get_current_weather, descriptions={ - "location": "The city and state, e.g. San Francisco, CA", + "location": "The city, e.g. San Francisco", "unit": "The temperature unit of measurement, e.g. 
celsius or fahrenheit", }, ) @@ -247,10 +260,29 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback) messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] - res = gemini_chat.run(messages=messages) - assert len(res["replies"]) > 0 + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) assert streaming_callback_called + # check the first response is a function call + chat_message = response["replies"][0] + assert "function_call" in chat_message.meta + assert chat_message.content == {"location": "Berlin", "unit": "celsius"} + + weather = get_current_weather(**response["replies"][0].content) + messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + + # check the second response is not a function call + chat_message = response["replies"][0] + assert "function_call" not in chat_message.meta + assert isinstance(chat_message.content, str) + @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") def test_past_conversation(): @@ -261,5 +293,7 @@ def test_past_conversation(): ChatMessage.from_assistant(content="It's an arithmetic operation."), ChatMessage.from_user(content="Yeah, but what's the result?"), ] - res = gemini_chat.run(messages=messages) - assert len(res["replies"]) > 0 + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index e693c10f4..ac4c93228 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -229,17 +229,24 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: :param response_body: The response from Vertex AI request. :returns: The extracted responses. 
""" - replies = [] + replies: List[ChatMessage] = [] for candidate in response_body.candidates: + metadata = candidate.to_dict() for part in candidate.content.parts: + # Remove content from metadata + metadata.pop("content", None) if part._raw_part.text != "": - replies.append(ChatMessage.from_assistant(part.text)) - elif part.function_call is not None: + replies.append( + ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata) + ) + elif part.function_call: + metadata["function_call"] = part.function_call replies.append( ChatMessage( content=dict(part.function_call.args.items()), role=ChatRole.ASSISTANT, name=part.function_call.name, + meta=metadata, ) ) return replies @@ -254,11 +261,27 @@ def _get_stream_response( :param streaming_callback: The handler for the streaming response. :returns: The extracted response with the content of all streaming chunks. """ - responses = [] + replies: List[ChatMessage] = [] + for chunk in stream: - streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict()) - streaming_callback(streaming_chunk) - responses.append(streaming_chunk.content) + content: Union[str, Dict[str, Any]] = "" + metadata = chunk.to_dict() # we store whole chunk as metadata for streaming + for candidate in chunk.candidates: + for part in candidate.content.parts: + if part._raw_part.text: + content = chunk.text + replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata)) + elif part.function_call: + metadata["function_call"] = part.function_call + content = dict(part.function_call.args.items()) + replies.append( + ChatMessage( + content=content, + role=ChatRole.ASSISTANT, + name=part.function_call.name, + meta=metadata, + ) + ) + streaming_callback(StreamingChunk(content=content, meta=metadata)) - combined_response = "".join(responses).lstrip() - return [ChatMessage.from_assistant(content=combined_response)] + return replies diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index a1564b9f2..ab21008fb 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -3,7 +3,7 @@ import pytest from haystack import Pipeline from haystack.components.builders import ChatPromptBuilder -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from vertexai.generative_models import ( Content, FunctionDeclaration, @@ -249,9 +249,12 @@ def test_run(mock_generative_model): ChatMessage.from_user("What's the capital of France?"), ] gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None) - gemini.run(messages=messages) + response = gemini.run(messages=messages) mock_model.send_message.assert_called_once() + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") @@ -260,25 +263,24 @@ def test_run_with_streaming_callback(mock_generative_model): mock_responses = iter( [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")] ) - mock_model.send_message.return_value = mock_responses mock_model.start_chat.return_value = mock_model mock_generative_model.return_value = mock_model streaming_callback_called = [] - def streaming_callback(chunk: StreamingChunk) -> None: 
- streaming_callback_called.append(chunk.content) + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback) messages = [ ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), ] - gemini.run(messages=messages) - + response = gemini.run(messages=messages) mock_model.send_message.assert_called_once() - assert streaming_callback_called == ["First part", "Second part"] + assert "replies" in response def test_serialization_deserialization_pipeline(): From 3be78821a89f1ddf8ce60a29a6a29a4105c36261 Mon Sep 17 00:00:00 2001 From: Alper Date: Wed, 25 Sep 2024 10:01:51 +0200 Subject: [PATCH 003/229] feat: Chroma - allow remote HTTP connection (#1094) * add http client * remove old code * update order * Apply suggestions from code review Co-authored-by: Stefano Fiorucci * add testcases * run chroma db in the bg * fix line too long * fix testcase * support chroma bg on windows * fix chroma on win * chroma fix for powershell on win * simplification * fix wrong skipif * linting --------- Co-authored-by: Stefano Fiorucci --- .github/workflows/chroma.yml | 4 ++ .../document_stores/chroma/document_store.py | 32 ++++++++++++-- .../chroma/tests/test_document_store.py | 44 +++++++++++++++++-- integrations/chroma/tests/test_retriever.py | 2 + 4 files changed, 76 insertions(+), 6 deletions(-) diff --git a/.github/workflows/chroma.yml b/.github/workflows/chroma.yml index 26b6287bd..6dbf36d85 100644 --- a/.github/workflows/chroma.yml +++ b/.github/workflows/chroma.yml @@ -56,6 +56,10 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run docs + - name: Run Chroma server on Linux/macOS + if: matrix.os != 'windows-latest' + run: hatch run chroma run & + - name: Run tests run: hatch run cov-retry diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 49729c0d4..359ace58d 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -33,6 +33,8 @@ def __init__( collection_name: str = "documents", embedding_function: str = "default", persist_path: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, distance_function: Literal["l2", "cosine", "ip"] = "l2", metadata: Optional[dict] = None, **embedding_function_params, @@ -48,7 +50,10 @@ def __init__( :param collection_name: the name of the collection to use in the database. :param embedding_function: the name of the embedding function to use to embed the query - :param persist_path: where to store the database. If None, the database will be `in-memory`. + :param persist_path: Path for local persistent storage. Cannot be used in combination with `host` and `port`. + If none of `persist_path`, `host`, and `port` is specified, the database will be `in-memory`. + :param host: The host address for the remote Chroma HTTP client connection. Cannot be used with `persist_path`. + :param port: The port number for the remote Chroma HTTP client connection. Cannot be used with `persist_path`. :param distance_function: The distance metric for the embedding space. 
- `"l2"` computes the Euclidean (straight-line) distance between vectors, where smaller scores indicate more similarity. @@ -75,12 +80,31 @@ def __init__( self._collection_name = collection_name self._embedding_function = embedding_function self._embedding_function_params = embedding_function_params - self._persist_path = persist_path self._distance_function = distance_function + + self._persist_path = persist_path + self._host = host + self._port = port + # Create the client instance - if persist_path is None: + if persist_path and (host or port is not None): + error_message = ( + "You must specify `persist_path` for local persistent storage or, " + "alternatively, `host` and `port` for remote HTTP client connection. " + "You cannot specify both options." + ) + raise ValueError(error_message) + if host and port is not None: + # Remote connection via HTTP client + self._chroma_client = chromadb.HttpClient( + host=host, + port=port, + ) + elif persist_path is None: + # In-memory storage self._chroma_client = chromadb.Client() else: + # Local persistent storage self._chroma_client = chromadb.PersistentClient(path=persist_path) embedding_func = get_embedding_function(embedding_function, **embedding_function_params) @@ -341,6 +365,8 @@ def to_dict(self) -> Dict[str, Any]: collection_name=self._collection_name, embedding_function=self._embedding_function, persist_path=self._persist_path, + host=self._host, + port=self._port, distance_function=self._distance_function, **self._embedding_function_params, ) diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index cd20bd398..3a6952ff8 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -6,6 +6,7 @@ import uuid from typing import List from unittest import mock +import sys import numpy as np import pytest @@ -66,6 +67,39 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert doc_received.content == doc_expected.content assert doc_received.meta == doc_expected.meta + def test_init_in_memory(self): + store = ChromaDocumentStore() + + assert store._persist_path is None + assert store._host is None + assert store._port is None + + def test_init_persistent_storage(self): + store = ChromaDocumentStore(persist_path="./path/to/local/store") + + assert store._persist_path == "./path/to/local/store" + assert store._host is None + assert store._port is None + + @pytest.mark.integration + @pytest.mark.skipif( + sys.platform == "win32", + reason="This test requires running the Chroma server. For simplicity, we don't run it on Windows.", + ) + def test_init_http_connection(self): + store = ChromaDocumentStore(host="localhost", port=8000) + + assert store._persist_path is None + assert store._host == "localhost" + assert store._port == 8000 + + def test_invalid_initialization_both_host_and_persist_path(self): + """ + Test that providing both host and persist_path raises an error. 
+ """ + with pytest.raises(ValueError): + ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + def test_delete_empty(self, document_store: ChromaDocumentStore): """ Deleting a non-existing document should not raise with Chroma @@ -125,7 +159,7 @@ def test_write_documents_unsupported_meta_values(self, document_store: ChromaDoc assert written_docs[2].meta == {"ok": 123} @pytest.mark.integration - def test_to_json(self, request): + def test_to_dict(self, request): ds = ChromaDocumentStore( collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890" ) @@ -133,16 +167,18 @@ def test_to_json(self, request): assert ds_dict == { "type": "haystack_integrations.document_stores.chroma.document_store.ChromaDocumentStore", "init_parameters": { - "collection_name": "test_to_json", + "collection_name": "test_to_dict", "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, + "host": None, + "port": None, "api_key": "1234567890", "distance_function": "l2", }, } @pytest.mark.integration - def test_from_json(self): + def test_from_dict(self): collection_name = "test_collection" function_name = "HuggingFaceEmbeddingFunction" ds_dict = { @@ -151,6 +187,8 @@ def test_from_json(self): "collection_name": "test_collection", "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, + "host": None, + "port": None, "api_key": "1234567890", "distance_function": "l2", }, diff --git a/integrations/chroma/tests/test_retriever.py b/integrations/chroma/tests/test_retriever.py index 99a4c34e6..645360033 100644 --- a/integrations/chroma/tests/test_retriever.py +++ b/integrations/chroma/tests/test_retriever.py @@ -38,6 +38,8 @@ def test_retriever_to_json(request): "collection_name": "test_retriever_to_json", "embedding_function": "HuggingFaceEmbeddingFunction", "persist_path": None, + "host": None, + "port": None, "api_key": "1234567890", "distance_function": "l2", }, From 7ee0a8173dddf4359a4201f20836b9898316beaa Mon Sep 17 00:00:00 2001 From: Daria Fokina Date: Wed, 25 Sep 2024 17:10:21 +0200 Subject: [PATCH 004/229] upd snowflake pydoc (#1102) --- integrations/snowflake/pydoc/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/snowflake/pydoc/config.yml b/integrations/snowflake/pydoc/config.yml index 3ebeaef51..9c03ff45a 100644 --- a/integrations/snowflake/pydoc/config.yml +++ b/integrations/snowflake/pydoc/config.yml @@ -19,8 +19,8 @@ renderer: excerpt: Snowflake integration for Haystack category_slug: integrations-api title: Snowflake - slug: integrations-Snowflake - order: 130 + slug: integrations-snowflake + order: 225 markdown: descriptive_class_title: false classdef_code_block: false From a7bb39e17922f88bb0d6070eba5360e3456ddb1e Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 25 Sep 2024 15:55:20 +0000 Subject: [PATCH 005/229] Update the changelog --- integrations/snowflake/CHANGELOG.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/integrations/snowflake/CHANGELOG.md b/integrations/snowflake/CHANGELOG.md index 0553a3f4b..757bfb3fe 100644 --- a/integrations/snowflake/CHANGELOG.md +++ b/integrations/snowflake/CHANGELOG.md @@ -1 +1,13 @@ -## [integrations/snowflake-v0.0.1] - 2024-09-06 \ No newline at end of file +# Changelog + +## [integrations/snowflake-v0.0.2] - 2024-09-25 + +### 🚀 Features + +- Add Snowflake integration (#1064) + +### ⚙️ Miscellaneous Tasks + +- Adding github workflow for Snowflake (#1097) + + From 
ac55f19ba61f8fb0e8c711bbc6c9139d781d595c Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 26 Sep 2024 10:06:32 +0200 Subject: [PATCH 006/229] adding snowflake to the labeler and README.MD (#1104) --- .github/labeler.yml | 5 +++++ README.md | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index d8bb71098..85f15788f 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -124,6 +124,11 @@ integration:weaviate: - any-glob-to-any-file: "integrations/weaviate/**/*" - any-glob-to-any-file: ".github/workflows/weaviate.yml" +integration:snowflake: + - changed-files: + - any-glob-to-any-file: "integrations/snowflake/**/*" + - any-glob-to-any-file: ".github/workflows/snowflake.yml" + integration:deepeval: - changed-files: - any-glob-to-any-file: "integrations/deepeval/**/*" diff --git a/README.md b/README.md index 7ba853d62..d80eff62c 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,8 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [ragas-haystack](integrations/ragas/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/ragas-haystack.svg)](https://pypi.org/project/ragas-haystack) | [![Test / ragas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | -| [uptrain-haystack](https://github.com/deepset-ai/haystack-core-integrations/tree/staging/integrations/uptrain) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [Staged](https://docs.haystack.deepset.ai/docs/breaking-change-policy#discontinuing-an-integration) | +| [uptrain-haystack](https://github.com/deepset-ai/haystack-core-integrations/tree/staging/integrations/uptrain) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [Staged](https://docs.haystack.deepset.ai/docs/breaking-change-policy#discontinuing-an-integration) | +| [snowflake-haystack](integrations/Snowflake/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml) | | [weaviate-haystack](integrations/weaviate/) | Document Store | [![PyPI - 
Version](https://img.shields.io/pypi/v/weaviate-haystack.svg)](https://pypi.org/project/weaviate-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | ## Releasing From 8004e7d11109fe0df85543453bac604f324366f9 Mon Sep 17 00:00:00 2001 From: brendancicchi Date: Thu, 26 Sep 2024 05:33:38 -0400 Subject: [PATCH 007/229] #1047 Remove count_documents from delete_documents (#1049) Removed the expensive check to see if the collection is non-empty by performing a full count. This is to fix issue #1047 Co-authored-by: Vladimir Blagojevic --- .../document_stores/astra/astra_client.py | 18 ++++++++++++++++++ .../document_stores/astra/document_store.py | 4 ++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index 5a88a0fe9..b594f87d3 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -231,6 +231,24 @@ def find_documents(self, find_query): else: logger.warning(f"No documents found: {response_dict}") + def find_one_document(self, find_query): + """ + Find one document in the Astra index. + + :param find_query: a dictionary with the query options + :returns: the document found in the index + """ + response_dict = self._astra_db_collection.find_one( + filter=find_query.get("filter"), + options=find_query.get("options"), + projection={"*": 1}, + ) + + if "data" in response_dict and "document" in response_dict["data"]: + return response_dict["data"]["document"] + else: + logger.warning(f"No document found: {response_dict}") + def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse: """ Get documents from the Astra index by their ids. diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 1dea6e08b..a7a7a231c 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -411,8 +411,8 @@ def delete_documents(self, document_ids: Optional[List[str]] = None, delete_all: :param delete_all: if `True`, delete all documents. :raises MissingDocumentError: if no document was deleted but document IDs were provided. """ - deletion_counter = 0 - if self.index.count_documents() > 0: + if self.index.find_one_document({"filter": {}}) is not None: + deletion_counter = 0 if document_ids is not None: for batch in _batches(document_ids, MAX_BATCH_SIZE): deletion_counter += self.index.delete(ids=batch) From 0f560c241c94da0a3874475e82b623dea77021fa Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Thu, 26 Sep 2024 11:42:02 +0200 Subject: [PATCH 008/229] fix: fixing README (#1109) * fixing README * Update README.md Co-authored-by: Stefano Fiorucci --------- Co-authored-by: Stefano Fiorucci --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d80eff62c..2b4a83253 100644 --- a/README.md +++ b/README.md @@ -51,9 +51,9 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [pgvector-haystack](integrations/pgvector/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/pgvector-haystack.svg?color=orange)](https://pypi.org/project/pgvector-haystack) | [![Test / pgvector](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pgvector.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pgvector.yml) | | [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [ragas-haystack](integrations/ragas/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/ragas-haystack.svg)](https://pypi.org/project/ragas-haystack) | [![Test / ragas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml) | +| [snowflake-haystack](integrations/snowflake/) | Retriever | [![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) | [![Test / snowflake](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | | [uptrain-haystack](https://github.com/deepset-ai/haystack-core-integrations/tree/staging/integrations/uptrain) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [Staged](https://docs.haystack.deepset.ai/docs/breaking-change-policy#discontinuing-an-integration) | -| [snowflake-haystack](integrations/Snowflake/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml) | | [weaviate-haystack](integrations/weaviate/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/weaviate-haystack.svg)](https://pypi.org/project/weaviate-haystack) | [![Test / 
weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | ## Releasing From dee3e774200257c111603ca78be2c8f8d9b763a1 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 26 Sep 2024 15:54:43 +0200 Subject: [PATCH 009/229] fix: Ollama Chat Generator - add missing `to_dict` and `from_dict` methods (#1110) * add missing to_dict/from_dict and tests * linting --- .../generators/ollama/chat/chat_generator.py | 36 +++++++++++++++++- .../ollama/tests/test_chat_generator.py | 37 +++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 1f3a0bf1e..9502a187e 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,7 +1,8 @@ from typing import Any, Callable, Dict, List, Optional -from haystack import component +from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from ollama import Client @@ -63,6 +64,39 @@ def __init__( self._client = Client(host=self.url, timeout=self.timeout) + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + model=self.model, + url=self.url, + generation_kwargs=self.generation_kwargs, + timeout=self.timeout, + streaming_callback=callback_name, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. 
+ """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) + def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: return {"role": message.role.value, "content": message.content} diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 79d70675a..a46758df3 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole from ollama._types import ResponseError @@ -39,6 +40,42 @@ def test_init(self): assert component.generation_kwargs == {"temperature": 0.5} assert component.timeout == 5 + def test_to_dict(self): + component = OllamaChatGenerator( + model="llama2", + streaming_callback=print_streaming_chunk, + url="custom_url", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", + "init_parameters": { + "timeout": 120, + "model": "llama2", + "url": "custom_url", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + + def test_from_dict(self): + data = { + "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", + "init_parameters": { + "timeout": 120, + "model": "llama2", + "url": "custom_url", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + component = OllamaChatGenerator.from_dict(data) + assert component.model == "llama2" + assert component.streaming_callback is print_streaming_chunk + assert component.url == "custom_url" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + def test_build_message_from_ollama_response(self): model = "some_model" From 7defbe9695b7cd6ff7614dd74c48486b435c0d3b Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 26 Sep 2024 13:56:45 +0000 Subject: [PATCH 010/229] Update the changelog --- integrations/ollama/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 8f51237e9..5da725e87 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/ollama-v1.0.1] - 2024-09-26 + +### 🐛 Bug Fixes + +- Ollama Chat Generator - add missing `to_dict` and `from_dict` methods (#1110) + ## [integrations/ollama-v1.0.0] - 2024-09-07 ### 🐛 Bug Fixes From 46bc945411f5f8be1a82708392a547b889bb241a Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 26 Sep 2024 16:46:45 +0200 Subject: [PATCH 011/229] feat: add custom params to VertexAIGeminiGenerator and VertexAIGeminiChatGenerator (#1100) * Added "tool_config" and "system_instruction" params --- .../generators/google_vertex/chat/gemini.py | 53 +++++++++++++++++-- .../generators/google_vertex/gemini.py | 52 ++++++++++++++++-- 
.../google_vertex/tests/chat/test_gemini.py | 42 +++++++++++++++ .../google_vertex/tests/test_gemini.py | 44 +++++++++++++++ 4 files changed, 182 insertions(+), 9 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index ac4c93228..f09692daf 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -17,6 +17,7 @@ HarmCategory, Part, Tool, + ToolConfig, ) logger = logging.getLogger(__name__) @@ -54,6 +55,8 @@ def __init__( generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + tool_config: Optional[ToolConfig] = None, + system_instruction: Optional[Union[str, ByteStream, Part]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -76,8 +79,11 @@ def __init__( :param tools: List of tools to use when generating content. See the documentation for [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) the list of supported arguments. + :param tool_config: The tool config to use. See the documentation for [ToolConfig] + (https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig) + :param system_instruction: Default system instruction to use for generating content. :param streaming_callback: A callback function that is called when a new token is received from - the stream. The callback function accepts StreamingChunk as an argument. + the stream. The callback function accepts StreamingChunk as an argument. 
""" @@ -87,13 +93,25 @@ def __init__( self._model_name = model self._project_id = project_id self._location = location - self._model = GenerativeModel(self._model_name) + # model parameters self._generation_config = generation_config self._safety_settings = safety_settings self._tools = tools + self._tool_config = tool_config + self._system_instruction = system_instruction self._streaming_callback = streaming_callback + # except streaming_callback, all other model parameters can be passed during initialization + self._model = GenerativeModel( + self._model_name, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, + tool_config=self._tool_config, + system_instruction=self._system_instruction, + ) + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): return config @@ -106,6 +124,17 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A "stop_sequences": config._raw_generation_config.stop_sequences, } + def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]: + """Serializes the ToolConfig object into a dictionary.""" + mode = tool_config._gapic_tool_config.function_calling_config.mode + allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names + config_dict = {"function_calling_config": {"mode": mode}} + + if allowed_function_names: + config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names + + return config_dict + def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -123,10 +152,14 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + tool_config=self._tool_config, + system_instruction=self._system_instruction, streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] + if (tool_config := data["init_parameters"].get("tool_config")) is not None: + data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config) if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -141,10 +174,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator": :returns: Deserialized component. 
""" + + def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: + """Deserializes the ToolConfig object from a dictionary.""" + function_calling_config = config_dict["function_calling_config"] + return ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=function_calling_config["mode"], + allowed_function_names=function_calling_config.get("allowed_function_names"), + ) + ) + if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) + if (tool_config := data["init_parameters"].get("tool_config")) is not None: + data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config) if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -212,9 +258,6 @@ def run( new_message = self._message_to_part(messages[-1]) res = session.send_message( content=new_message, - generation_config=self._generation_config, - safety_settings=self._safety_settings, - tools=self._tools, stream=streaming_callback is not None, ) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 7394211bf..2b1c1b477 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -16,6 +16,7 @@ HarmCategory, Part, Tool, + ToolConfig, ) logger = logging.getLogger(__name__) @@ -58,6 +59,8 @@ def __init__( generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + tool_config: Optional[ToolConfig] = None, + system_instruction: Optional[Union[str, ByteStream, Part]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -86,6 +89,8 @@ def __init__( :param tools: List of tools to use when generating content. See the documentation for [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) the list of supported arguments. + :param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig) + :param system_instruction: Default system instruction to use for generating content. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. 
""" @@ -96,13 +101,25 @@ def __init__( self._model_name = model self._project_id = project_id self._location = location - self._model = GenerativeModel(self._model_name) + # model parameters self._generation_config = generation_config self._safety_settings = safety_settings self._tools = tools + self._tool_config = tool_config + self._system_instruction = system_instruction self._streaming_callback = streaming_callback + # except streaming_callback, all other model parameters can be passed during initialization + self._model = GenerativeModel( + self._model_name, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, + tool_config=self._tool_config, + system_instruction=self._system_instruction, + ) + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): return config @@ -115,6 +132,18 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A "stop_sequences": config._raw_generation_config.stop_sequences, } + def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]: + """Serializes the ToolConfig object into a dictionary.""" + + mode = tool_config._gapic_tool_config.function_calling_config.mode + allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names + config_dict = {"function_calling_config": {"mode": mode}} + + if allowed_function_names: + config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names + + return config_dict + def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -132,10 +161,14 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + tool_config=self._tool_config, + system_instruction=self._system_instruction, streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] + if (tool_config := data["init_parameters"].get("tool_config")) is not None: + data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config) if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -150,10 +183,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": :returns: Deserialized component. 
""" + + def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: + """Deserializes the ToolConfig object from a dictionary.""" + function_calling_config = config_dict["function_calling_config"] + return ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=function_calling_config["mode"], + allowed_function_names=function_calling_config.get("allowed_function_names"), + ) + ) + if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) + if (tool_config := data["init_parameters"].get("tool_config")) is not None: + data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config) if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -188,11 +234,9 @@ def run( converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] + res = self._model.generate_content( contents=contents, - generation_config=self._generation_config, - safety_settings=self._safety_settings, - tools=self._tools, stream=streaming_callback is not None, ) self._model.start_chat() diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index ab21008fb..87b43d66b 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -13,6 +13,7 @@ HarmCategory, Part, Tool, + ToolConfig, ) from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator @@ -60,6 +61,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + tool_config = ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=ToolConfig.FunctionCallingConfig.Mode.ANY, + allowed_function_names=["get_current_weather_func"], + ) + ) gemini = VertexAIGeminiChatGenerator( project_id="TestID123", @@ -67,12 +74,16 @@ def test_init(mock_vertexai_init, _mock_generative_model): generation_config=generation_config, safety_settings=safety_settings, tools=[tool], + tool_config=tool_config, + system_instruction="Please provide brief answers.", ) mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings assert gemini._tools == [tool] + assert gemini._tool_config == tool_config + assert gemini._system_instruction == "Please provide brief answers." 
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") @@ -92,6 +103,8 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model): "safety_settings": None, "streaming_callback": None, "tools": None, + "tool_config": None, + "system_instruction": None, }, } @@ -110,12 +123,20 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + tool_config = ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=ToolConfig.FunctionCallingConfig.Mode.ANY, + allowed_function_names=["get_current_weather_func"], + ) + ) gemini = VertexAIGeminiChatGenerator( project_id="TestID123", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], + tool_config=tool_config, + system_instruction="Please provide brief answers.", ) assert gemini.to_dict() == { @@ -155,6 +176,13 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): ] } ], + "tool_config": { + "function_calling_config": { + "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, + "allowed_function_names": ["get_current_weather_func"], + } + }, + "system_instruction": "Please provide brief answers.", }, } @@ -180,6 +208,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._safety_settings is None assert gemini._tools is None + assert gemini._tool_config is None + assert gemini._system_instruction is None assert gemini._generation_config is None @@ -222,6 +252,13 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): ] } ], + "tool_config": { + "function_calling_config": { + "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, + "allowed_function_names": ["get_current_weather_func"], + } + }, + "system_instruction": "Please provide brief answers.", "streaming_callback": None, }, } @@ -231,7 +268,12 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) + assert isinstance(gemini._tool_config, ToolConfig) assert isinstance(gemini._generation_config, GenerationConfig) + assert gemini._system_instruction == "Please provide brief answers." 
+ assert ( + gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY + ) @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index bb96ec409..1543f3ccf 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -9,6 +9,7 @@ HarmBlockThreshold, HarmCategory, Tool, + ToolConfig, ) from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator @@ -48,6 +49,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + tool_config = ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=ToolConfig.FunctionCallingConfig.Mode.ANY, + allowed_function_names=["get_current_weather_func"], + ) + ) gemini = VertexAIGeminiGenerator( project_id="TestID123", @@ -55,12 +62,16 @@ def test_init(mock_vertexai_init, _mock_generative_model): generation_config=generation_config, safety_settings=safety_settings, tools=[tool], + tool_config=tool_config, + system_instruction="Please provide brief answers.", ) mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings assert gemini._tools == [tool] + assert gemini._tool_config == tool_config + assert gemini._system_instruction == "Please provide brief answers." @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @@ -80,6 +91,8 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model): "safety_settings": None, "streaming_callback": None, "tools": None, + "tool_config": None, + "system_instruction": None, }, } @@ -98,12 +111,20 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + tool_config = ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=ToolConfig.FunctionCallingConfig.Mode.ANY, + allowed_function_names=["get_current_weather_func"], + ) + ) gemini = VertexAIGeminiGenerator( project_id="TestID123", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], + tool_config=tool_config, + system_instruction="Please provide brief answers.", ) assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", @@ -142,6 +163,13 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): ] } ], + "tool_config": { + "function_calling_config": { + "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, + "allowed_function_names": ["get_current_weather_func"], + } + }, + "system_instruction": "Please provide brief answers.", }, } @@ -159,6 +187,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): "safety_settings": None, "tools": None, "streaming_callback": None, + "tool_config": None, + "system_instruction": None, }, } ) @@ -167,6 +197,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._safety_settings is None assert 
gemini._tools is None + assert gemini._tool_config is None + assert gemini._system_instruction is None assert gemini._generation_config is None @@ -210,6 +242,13 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): } ], "streaming_callback": None, + "tool_config": { + "function_calling_config": { + "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, + "allowed_function_names": ["get_current_weather_func"], + } + }, + "system_instruction": "Please provide brief answers.", }, } ) @@ -219,6 +258,11 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) + assert isinstance(gemini._tool_config, ToolConfig) + assert gemini._system_instruction == "Please provide brief answers." + assert ( + gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY + ) @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") From cdb7dffad7c58e713491e244d98161e79d13a29b Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 26 Sep 2024 17:21:17 +0200 Subject: [PATCH 012/229] pin llama-cpp-python<0.3.0 (#1111) --- integrations/llama_cpp/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index d57bd9fc9..4fe0d82b0 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "llama-cpp-python>=0.2.87"] +dependencies = ["haystack-ai", "llama-cpp-python>=0.2.87,<0.3.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama_cpp#readme" From 907c10b8d07b2fb50f2be9870f33836dedf4a167 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 26 Sep 2024 17:50:40 +0200 Subject: [PATCH 013/229] chore: update ruff linting scripts and settings (#1105) * updates * more changes * lint * more changes * more changes * mmore and more changes * right concurrency group for anthropic * apply suggestions --- .github/workflows/anthropic.yml | 2 +- integrations/amazon_bedrock/pyproject.toml | 10 ++++++---- integrations/amazon_sagemaker/pyproject.toml | 10 ++++++---- integrations/anthropic/pyproject.toml | 10 ++++++---- integrations/astra/pyproject.toml | 9 +++++---- integrations/chroma/pyproject.toml | 20 ++++++++++--------- .../chroma/tests/test_document_store.py | 2 +- integrations/cohere/pyproject.toml | 14 +++++++------ .../tests/test_cohere_chat_generator.py | 1 + .../cohere/tests/test_cohere_generator.py | 1 + .../cohere/tests/test_cohere_ranker.py | 1 + .../cohere/tests/test_document_embedder.py | 1 + .../cohere/tests/test_text_embedder.py | 1 + integrations/deepeval/example/example.py | 1 + integrations/deepeval/pyproject.toml | 14 +++++++------ integrations/deepeval/tests/test_evaluator.py | 2 +- integrations/elasticsearch/pyproject.toml | 14 +++++++------ integrations/fastembed/pyproject.toml | 14 +++++++------ integrations/google_ai/pyproject.toml | 10 ++++++---- integrations/google_vertex/pyproject.toml | 10 ++++++---- 
.../instructor_embedders/pyproject.toml | 14 +++++++------ integrations/jina/pyproject.toml | 14 +++++++------ integrations/langfuse/example/basic_rag.py | 1 + integrations/langfuse/pyproject.toml | 17 ++++++++-------- .../langfuse/tests/test_langfuse_span.py | 2 ++ integrations/langfuse/tests/test_tracer.py | 2 +- integrations/langfuse/tests/test_tracing.py | 9 +++++---- integrations/llama_cpp/pyproject.toml | 14 +++++++------ integrations/mistral/pyproject.toml | 14 +++++++------ .../tests/test_mistral_chat_generator.py | 3 ++- .../tests/test_mistral_document_embedder.py | 1 + .../tests/test_mistral_text_embedder.py | 1 + integrations/mongodb_atlas/pyproject.toml | 14 +++++++------ integrations/nvidia/pyproject.toml | 14 +++++++------ integrations/ollama/pyproject.toml | 14 +++++++------ integrations/opensearch/pyproject.toml | 14 +++++++------ integrations/optimum/example/example.py | 4 ++-- integrations/optimum/pyproject.toml | 8 +++++--- .../tests/test_optimum_document_embedder.py | 13 ++++++------ .../tests/test_optimum_text_embedder.py | 5 +++-- integrations/pgvector/pyproject.toml | 14 +++++++------ integrations/pinecone/pyproject.toml | 10 ++++++---- integrations/qdrant/pyproject.toml | 10 ++++++---- integrations/ragas/pyproject.toml | 14 +++++++------ integrations/ragas/tests/test_evaluator.py | 3 ++- integrations/snowflake/pyproject.toml | 14 +++++++------ .../tests/test_snowflake_table_retriever.py | 12 ++++++++--- integrations/unstructured/pyproject.toml | 14 +++++++------ .../unstructured/tests/test_converter.py | 1 + integrations/weaviate/pyproject.toml | 14 +++++++------ 50 files changed, 253 insertions(+), 178 deletions(-) diff --git a/.github/workflows/anthropic.yml b/.github/workflows/anthropic.yml index c4cdeb2d1..52ba5c9d4 100644 --- a/.github/workflows/anthropic.yml +++ b/.github/workflows/anthropic.yml @@ -15,7 +15,7 @@ defaults: working-directory: integrations/anthropic concurrency: - group: cohere-${{ github.head_ref }} + group: anthropic-${{ github.head_ref }} cancel-in-progress: true env: diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index f4a410dbd..1298abfab 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -67,7 +67,7 @@ typing = "mypy --install-types --non-interactive --explicit-package-bases {args: style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] @@ -79,6 +79,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -128,13 +130,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml index f8050bb48..a25b806f6 100644 --- a/integrations/amazon_sagemaker/pyproject.toml +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -70,7 +70,7 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types 
--non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -81,6 +81,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -131,13 +133,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/anthropic/pyproject.toml b/integrations/anthropic/pyproject.toml index e1d3fa867..987f017be 100644 --- a/integrations/anthropic/pyproject.toml +++ b/integrations/anthropic/pyproject.toml @@ -67,7 +67,7 @@ typing = "mypy --install-types --non-interactive --explicit-package-bases {args: style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] @@ -79,6 +79,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -129,13 +131,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index 7d543ddc9..25bcf20b8 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -78,7 +78,9 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 -lint.select = [ + +[tool.ruff.lint] +select = [ "A", "ARG", "B", @@ -105,7 +107,7 @@ lint.select = [ "W", "YTT", ] -lint.ignore = [ +ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` @@ -121,11 +123,10 @@ lint.ignore = [ "PLR0913", "PLR0915", ] -lint.unfixable = [ +unfixable = [ # Don't touch unused imports "F401", ] -lint.exclude = ["example"] [tool.ruff.lint.isort] known-first-party = ["haystack_integrations"] diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index cebfa1b9d..2bffabfd8 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -70,8 +70,8 @@ dependencies = [ ] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.hatch.metadata] @@ -85,7 +85,9 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 -lint.select = [ + +[tool.ruff.lint] +select = [ "A", "ARG", "B", @@ -112,7 +114,7 @@ lint.select = [ "W", "YTT", ] -lint.ignore = [ +ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` @@ -130,19 +132,19 @@ lint.ignore = [ # Ignore unused params "ARG002", ] -lint.unfixable = [ +unfixable = [ # Don't touch unused imports "F401", ] exclude = ["example"] -[tool.ruff.isort] -known-first-party = ["src", "example", "tests"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] "example/**/*" = ["T201"] diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 3a6952ff8..d33086945 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 import logging import operator +import sys import uuid from typing import List from unittest import mock -import sys import numpy as np import pytest diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index add8dc150..d86165668 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -65,10 +65,10 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ - "ruff check {args:. --exclude tests/}", + "ruff check {args:.}", "black --check --diff {args:.}", ] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.black] @@ -79,6 +79,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -127,13 +129,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index fe9b7f43e..175a6d14b 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -7,6 +7,7 @@ from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.utils import Secret + from haystack_integrations.components.generators.cohere import CohereChatGenerator pytestmark = pytest.mark.chat_generators diff --git a/integrations/cohere/tests/test_cohere_generator.py b/integrations/cohere/tests/test_cohere_generator.py index 736b6bfbf..60ee6ac93 100644 --- a/integrations/cohere/tests/test_cohere_generator.py +++ b/integrations/cohere/tests/test_cohere_generator.py @@ -7,6 +7,7 @@ from cohere.core import ApiError from haystack.components.generators.utils import print_streaming_chunk from haystack.utils import Secret + from haystack_integrations.components.generators.cohere import CohereGenerator pytestmark = pytest.mark.generators diff --git a/integrations/cohere/tests/test_cohere_ranker.py b/integrations/cohere/tests/test_cohere_ranker.py index 670e662d4..ff861b39d 100644 --- a/integrations/cohere/tests/test_cohere_ranker.py +++ b/integrations/cohere/tests/test_cohere_ranker.py @@ -4,6 +4,7 @@ import pytest from haystack import Document from haystack.utils.auth import Secret + from haystack_integrations.components.rankers.cohere import CohereRanker pytestmark = pytest.mark.ranker diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index ffbf280e9..d69e1a5a2 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -6,6 +6,7 @@ import pytest from haystack import Document from haystack.utils import Secret + from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder pytestmark = pytest.mark.embedders diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index b4f3e234c..80f7c1a3e 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -5,6 +5,7 @@ import pytest from haystack.utils import Secret + from haystack_integrations.components.embedders.cohere import CohereTextEmbedder pytestmark = pytest.mark.embedders diff --git a/integrations/deepeval/example/example.py b/integrations/deepeval/example/example.py index e1265a739..97a26ef34 100644 --- a/integrations/deepeval/example/example.py +++ b/integrations/deepeval/example/example.py @@ -1,6 +1,7 @@ # A valid OpenAI API key is required to run this example. 
from haystack import Pipeline + from haystack_integrations.components.evaluators.deepeval import DeepEvalEvaluator, DeepEvalMetric QUESTIONS = [ diff --git a/integrations/deepeval/pyproject.toml b/integrations/deepeval/pyproject.toml index 44d89cb11..5d81fa0a5 100644 --- a/integrations/deepeval/pyproject.toml +++ b/integrations/deepeval/pyproject.toml @@ -60,7 +60,7 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -71,6 +71,9 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 +exclude = ["example", "tests"] + +[tool.ruff.lint] select = [ "A", "ARG", @@ -121,15 +124,14 @@ unfixable = [ # Don't touch unused imports "F401", ] -extend-exclude = ["tests", "example"] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/deepeval/tests/test_evaluator.py b/integrations/deepeval/tests/test_evaluator.py index 7d1946185..24b9ba7ea 100644 --- a/integrations/deepeval/tests/test_evaluator.py +++ b/integrations/deepeval/tests/test_evaluator.py @@ -5,10 +5,10 @@ from unittest.mock import patch import pytest +from deepeval.evaluate import BaseMetric, TestResult from haystack import DeserializationError from haystack_integrations.components.evaluators.deepeval import DeepEvalEvaluator, DeepEvalMetric -from deepeval.evaluate import TestResult, BaseMetric DEFAULT_QUESTIONS = [ "Which is the most popular global sport?", diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index 4e13b1c23..47b168f30 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -65,8 +65,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +style = ["ruff check {args:}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.hatch.metadata] @@ -80,6 +80,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -128,13 +130,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/fastembed/pyproject.toml b/integrations/fastembed/pyproject.toml index 4ebd765dd..69aba5562 100644 --- a/integrations/fastembed/pyproject.toml +++ b/integrations/fastembed/pyproject.toml @@ -66,8 +66,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/, examples/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/, examples/}", "style"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -78,6 +78,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -128,13 +130,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] # examples can contain "print" commands diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index db958a487..d06e0a53f 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -64,7 +64,7 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -75,6 +75,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -123,13 +125,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index 747bbecbf..a0cefbcd4 
100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -64,7 +64,7 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -75,6 +75,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -120,13 +122,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/instructor_embedders/pyproject.toml b/integrations/instructor_embedders/pyproject.toml index 017062a47..458c0ae0c 100644 --- a/integrations/instructor_embedders/pyproject.toml +++ b/integrations/instructor_embedders/pyproject.toml @@ -88,8 +88,8 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/}", "style"] +style = ["ruff check {args:}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.coverage.run] @@ -105,6 +105,8 @@ exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -152,13 +154,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["instructor_embedders"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index 908633686..cbd8df479 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -62,8 +62,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +style = ["ruff check {args:}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.black] @@ -74,6 +74,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -119,13 +121,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["jina_haystack"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/langfuse/example/basic_rag.py b/integrations/langfuse/example/basic_rag.py index 492a14d49..b1d5e620f 100644 --- a/integrations/langfuse/example/basic_rag.py +++ b/integrations/langfuse/example/basic_rag.py @@ -10,6 +10,7 @@ from haystack.components.generators import OpenAIGenerator from haystack.components.retrievers import InMemoryEmbeddingRetriever from haystack.document_stores.in_memory import InMemoryDocumentStore + from haystack_integrations.components.connectors.langfuse import LangfuseConnector diff --git a/integrations/langfuse/pyproject.toml b/integrations/langfuse/pyproject.toml index 6f9213be7..61de4596c 100644 --- a/integrations/langfuse/pyproject.toml +++ b/integrations/langfuse/pyproject.toml @@ -69,8 +69,8 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +style = ["ruff check {args:}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.hatch.metadata] @@ -84,7 +84,10 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 -lint.select = [ +exclude = ["example", "tests"] + +[tool.ruff.lint] +select = [ "A", "ARG", "B", @@ -110,8 +113,7 @@ lint.select = [ "W", "YTT", ] - -lint.ignore = [ +ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible passwords @@ -127,14 +129,13 @@ lint.ignore = [ # Asserts "S101", ] -lint.unfixable = [ +unfixable = [ # Don't touch unused imports "F401", ] -extend-exclude = ["tests", "example"] [tool.ruff.lint.isort] -known-first-party = ["src"] +known-first-party = ["haystack_integrations"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" diff --git a/integrations/langfuse/tests/test_langfuse_span.py b/integrations/langfuse/tests/test_langfuse_span.py index a5a5f2c13..7ea82ba97 100644 --- a/integrations/langfuse/tests/test_langfuse_span.py +++ b/integrations/langfuse/tests/test_langfuse_span.py @@ -3,7 +3,9 @@ os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" from unittest.mock import Mock + from haystack.dataclasses import ChatMessage + from haystack_integrations.tracing.langfuse.tracer import LangfuseSpan diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index 241581a72..c6bf4acdf 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -1,5 +1,5 @@ import os -from unittest.mock import Mock, MagicMock, patch +from unittest.mock import MagicMock, Mock, patch from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 05acf750e..936064e0a 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -1,16 +1,17 @@ import os import random import time -import pytest from urllib.parse import urlparse + +import pytest import requests -from requests.auth import HTTPBasicAuth from haystack import Pipeline from haystack.components.builders import ChatPromptBuilder -from haystack.dataclasses import ChatMessage -from haystack_integrations.components.connectors.langfuse import LangfuseConnector from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from requests.auth import HTTPBasicAuth +from haystack_integrations.components.connectors.langfuse import LangfuseConnector from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator from haystack_integrations.components.generators.cohere import CohereChatGenerator diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index 4fe0d82b0..673df575a 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -70,15 +70,15 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/, examples/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/, examples/}", "style"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] [tool.black] target-version = ["py38"] @@ -88,6 +88,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -133,10 +135,10 @@ unfixable = [ "F401", ] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] # Examples can print their output diff --git a/integrations/mistral/pyproject.toml b/integrations/mistral/pyproject.toml index 8e28c2c06..16f332331 100644 --- a/integrations/mistral/pyproject.toml +++ b/integrations/mistral/pyproject.toml @@ -65,10 +65,10 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ - "ruff check {args:. --exclude tests/}", + "ruff check {args:}", "black --check --diff {args:.}", ] -fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.black] @@ -79,6 +79,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -127,13 +129,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py index 181397c00..3c95f19db 100644 --- a/integrations/mistral/tests/test_mistral_chat_generator.py +++ b/integrations/mistral/tests/test_mistral_chat_generator.py @@ -7,11 +7,12 @@ from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.utils.auth import Secret -from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator from openai import OpenAIError from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice +from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator + @pytest.fixture def chat_messages(): diff --git a/integrations/mistral/tests/test_mistral_document_embedder.py b/integrations/mistral/tests/test_mistral_document_embedder.py index 6e5c11759..4e710e45d 100644 --- a/integrations/mistral/tests/test_mistral_document_embedder.py +++ b/integrations/mistral/tests/test_mistral_document_embedder.py @@ -6,6 +6,7 @@ import pytest from haystack import Document from haystack.utils import Secret + from 
haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder pytestmark = pytest.mark.embedders diff --git a/integrations/mistral/tests/test_mistral_text_embedder.py b/integrations/mistral/tests/test_mistral_text_embedder.py index af004b022..175a96e0f 100644 --- a/integrations/mistral/tests/test_mistral_text_embedder.py +++ b/integrations/mistral/tests/test_mistral_text_embedder.py @@ -5,6 +5,7 @@ import pytest from haystack.utils import Secret + from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder pytestmark = pytest.mark.embedders diff --git a/integrations/mongodb_atlas/pyproject.toml b/integrations/mongodb_atlas/pyproject.toml index f8fcf4254..95ed6c03a 100644 --- a/integrations/mongodb_atlas/pyproject.toml +++ b/integrations/mongodb_atlas/pyproject.toml @@ -66,8 +66,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/, examples/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/, examples/}", "style"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -78,6 +78,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -126,13 +128,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] # examples can contain "print" commands diff --git a/integrations/nvidia/pyproject.toml b/integrations/nvidia/pyproject.toml index 82fb32b95..b5c6dd205 100644 --- a/integrations/nvidia/pyproject.toml +++ b/integrations/nvidia/pyproject.toml @@ -67,10 +67,10 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ - "ruff check {args:. --exclude tests/}", + "ruff check {args:}", "black --check --diff {args:.}", ] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.black] @@ -81,6 +81,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -129,13 +131,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index 1174d3b78..bc8555140 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -71,15 +71,15 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/, examples/ }", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/, examples/ }", "style"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] [tool.black] target-version = ["py38"] @@ -89,6 +89,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -134,10 +136,10 @@ unfixable = [ "F401", ] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] # Examples can print their output diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 6be86727e..24f1653bd 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -67,8 +67,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "boto3"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff check --fix {args:. 
--exclude tests/}", "style"] +style = ["ruff check {args:}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.hatch.metadata] @@ -82,6 +82,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -130,13 +132,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/optimum/example/example.py b/integrations/optimum/example/example.py index 0d86ce99b..a1e22f575 100644 --- a/integrations/optimum/example/example.py +++ b/integrations/optimum/example/example.py @@ -3,10 +3,10 @@ from haystack import Pipeline from haystack_integrations.components.embedders.optimum import ( - OptimumTextEmbedder, - OptimumEmbedderPooling, OptimumEmbedderOptimizationConfig, OptimumEmbedderOptimizationMode, + OptimumEmbedderPooling, + OptimumTextEmbedder, ) pipeline = Pipeline() diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml index 2e0fb26a4..305af6042 100644 --- a/integrations/optimum/pyproject.toml +++ b/integrations/optimum/pyproject.toml @@ -82,14 +82,14 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true [tool.ruff.lint.isort] -known-first-party = ["src"] +known-first-party = ["haystack_integrations"] [tool.black] target-version = ["py38"] @@ -99,6 +99,9 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 +exclude = ["example", "tests"] + +[tool.ruff.lint] select = [ "A", "ARG", @@ -145,7 +148,6 @@ unfixable = [ # Don't touch unused imports "F401", ] -extend-exclude = ["tests", "example"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" diff --git a/integrations/optimum/tests/test_optimum_document_embedder.py b/integrations/optimum/tests/test_optimum_document_embedder.py index 9288bb688..7c8ca02e0 100644 --- a/integrations/optimum/tests/test_optimum_document_embedder.py +++ b/integrations/optimum/tests/test_optimum_document_embedder.py @@ -1,21 +1,22 @@ -from unittest.mock import MagicMock, patch -import tempfile import copy +import tempfile +from unittest.mock import MagicMock, patch import pytest from haystack.dataclasses import Document from haystack.utils.auth import Secret +from huggingface_hub.utils import RepositoryNotFoundError + from haystack_integrations.components.embedders.optimum import OptimumDocumentEmbedder -from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling from haystack_integrations.components.embedders.optimum.optimization import ( OptimumEmbedderOptimizationConfig, OptimumEmbedderOptimizationMode, ) +from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling from 
haystack_integrations.components.embedders.optimum.quantization import ( OptimumEmbedderQuantizationConfig, OptimumEmbedderQuantizationMode, ) -from huggingface_hub.utils import RepositoryNotFoundError @pytest.fixture @@ -147,9 +148,7 @@ def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): assert embedder._backend.parameters.optimizer_settings is None assert embedder._backend.parameters.quantizer_settings is None - def test_to_and_from_dict_with_custom_init_parameters( - self, mock_check_valid_model, mock_get_pooling_mode - ): # noqa: ARG002 + def test_to_and_from_dict_with_custom_init_parameters(self, mock_check_valid_model, mock_get_pooling_mode): component = OptimumDocumentEmbedder( model="sentence-transformers/all-minilm-l6-v2", token=Secret.from_env_var("ENV_VAR", strict=False), diff --git a/integrations/optimum/tests/test_optimum_text_embedder.py b/integrations/optimum/tests/test_optimum_text_embedder.py index ad0e7d800..db42ec26d 100644 --- a/integrations/optimum/tests/test_optimum_text_embedder.py +++ b/integrations/optimum/tests/test_optimum_text_embedder.py @@ -2,17 +2,18 @@ import pytest from haystack.utils.auth import Secret +from huggingface_hub.utils import RepositoryNotFoundError + from haystack_integrations.components.embedders.optimum import OptimumTextEmbedder -from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling from haystack_integrations.components.embedders.optimum.optimization import ( OptimumEmbedderOptimizationConfig, OptimumEmbedderOptimizationMode, ) +from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling from haystack_integrations.components.embedders.optimum.quantization import ( OptimumEmbedderQuantizationConfig, OptimumEmbedderQuantizationMode, ) -from huggingface_hub.utils import RepositoryNotFoundError @pytest.fixture diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index 8f4c2447b..014d163bc 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -66,8 +66,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/, examples/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/, examples/}", "style"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -78,6 +78,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -128,13 +130,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] # examples can contain "print" commands diff --git a/integrations/pinecone/pyproject.toml b/integrations/pinecone/pyproject.toml index 866385dd3..3f2e4d6bd 100644 --- a/integrations/pinecone/pyproject.toml +++ b/integrations/pinecone/pyproject.toml @@ -71,7 +71,7 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.hatch.metadata] @@ -85,6 +85,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -133,13 +135,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] # examples can contain "print" commands diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index 225844f22..898fd2dcf 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -62,8 +62,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/, examples/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/, examples/}", "style"] +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -74,6 +74,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -125,10 +127,10 @@ unfixable = [ "F401", ] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] # examples can contain "print" commands diff --git a/integrations/ragas/pyproject.toml b/integrations/ragas/pyproject.toml index d9ae6ca02..dd56e35f6 100644 --- a/integrations/ragas/pyproject.toml +++ b/integrations/ragas/pyproject.toml @@ -66,7 +66,7 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] all = ["style", "typing"] [tool.black] @@ -77,6 +77,9 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 +exclude = ["example", "tests"] + +[tool.ruff.lint] select = [ "A", "ARG", @@ -128,15 +131,14 @@ unfixable = [ # Don't touch unused imports "F401", ] -extend-exclude = ["tests", "example"] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/ragas/tests/test_evaluator.py b/integrations/ragas/tests/test_evaluator.py index fc8901c32..0f847ed0b 100644 --- a/integrations/ragas/tests/test_evaluator.py +++ b/integrations/ragas/tests/test_evaluator.py @@ -5,10 +5,11 @@ import pytest from datasets import Dataset from haystack import DeserializationError -from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric from ragas.evaluation import Result from ragas.metrics.base import Metric +from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric + DEFAULT_QUESTIONS = [ "Which is the most popular global sport?", "Who created the Python language?", diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml index 68f9ec477..355e9d090 100644 --- a/integrations/snowflake/pyproject.toml +++ b/integrations/snowflake/pyproject.toml @@ -62,8 +62,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +style = ["ruff check {args:}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.black] @@ -74,6 +74,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -123,13 +125,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["snowflake_haystack"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py index 547f7e1b1..f5b8fee37 100644 --- a/integrations/snowflake/tests/test_snowflake_table_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -10,9 +10,9 @@ import pytest from dateutil.tz import tzlocal from haystack import Pipeline +from haystack.components.builders import PromptBuilder from haystack.components.converters import OutputAdapter from haystack.components.generators import OpenAIGenerator -from haystack.components.builders import PromptBuilder from haystack.utils import Secret from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice @@ -478,7 +478,10 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: data = component.to_dict() assert data == { - "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", + "type": ( + "haystack_integrations.components.retrievers.snowflake." + "snowflake_table_retriever.SnowflakeTableRetriever" + ), "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -510,7 +513,10 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: data = component.to_dict() assert data == { - "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", + "type": ( + "haystack_integrations.components.retrievers.snowflake." + "snowflake_table_retriever.SnowflakeTableRetriever" + ), "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], diff --git a/integrations/unstructured/pyproject.toml b/integrations/unstructured/pyproject.toml index 6811753d9..88bd463b2 100644 --- a/integrations/unstructured/pyproject.toml +++ b/integrations/unstructured/pyproject.toml @@ -65,10 +65,10 @@ dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ - "ruff check {args:. --exclude tests/}", + "ruff check {args:}", "black --check --diff {args:.}", ] -fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.hatch.metadata] @@ -82,6 +82,8 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 + +[tool.ruff.lint] select = [ "A", "ARG", @@ -130,13 +132,13 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["src"] +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] -[tool.ruff.flake8-tidy-imports] +[tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] diff --git a/integrations/unstructured/tests/test_converter.py b/integrations/unstructured/tests/test_converter.py index 5d1a6c091..063289b07 100644 --- a/integrations/unstructured/tests/test_converter.py +++ b/integrations/unstructured/tests/test_converter.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import pytest + from haystack_integrations.components.converters.unstructured import UnstructuredFileConverter diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index 624a06f1d..22d3a160d 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -65,8 +65,8 @@ detached = true dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] -fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/}", "style"] +style = ["ruff check {args:}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:}", "style"] all = ["style", "typing"] [tool.black] @@ -77,7 +77,9 @@ skip-string-normalization = true [tool.ruff] target-version = "py38" line-length = 120 -lint.select = [ + +[tool.ruff.lint] +select = [ "A", "ARG", "B", @@ -104,7 +106,7 @@ lint.select = [ "W", "YTT", ] -lint.ignore = [ +ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Allow boolean positional values in function calls, like `dict.get(... 
True)` @@ -120,13 +122,13 @@ lint.ignore = [ "PLR0913", "PLR0915", ] -lint.unfixable = [ +unfixable = [ # Don't touch unused imports "F401", ] [tool.ruff.lint.isort] -known-first-party = ["src"] +known-first-party = ["haystack_integrations"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" From 242e3c5d4eabaf969ea19308f70dd83b6bebffb3 Mon Sep 17 00:00:00 2001 From: Alper Date: Mon, 30 Sep 2024 11:11:15 +0200 Subject: [PATCH 014/229] feat: Chroma - defer the DB connection (#1107) * defer DB from chroma * added ensure_initialized * addressed comments * simplification and linting * refinements to docstrings --------- Co-authored-by: Stefano Fiorucci --- integrations/chroma/pyproject.toml | 2 + .../document_stores/chroma/document_store.py | 103 +++++++++++------- .../chroma/tests/test_document_store.py | 10 +- 3 files changed, 74 insertions(+), 41 deletions(-) diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 2bffabfd8..27b204432 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -131,6 +131,8 @@ ignore = [ "PLR0915", # Ignore unused params "ARG002", + # Allow assertions + "S101", ] unfixable = [ # Don't touch unused imports diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 359ace58d..990aa4c34 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -40,9 +40,8 @@ def __init__( **embedding_function_params, ): """ - Initializes the store. The __init__ constructor is not part of the Store Protocol - and the signature can be customized to your needs. For example, parameters needed - to set up a database client would be passed to this method. + Creates a new ChromaDocumentStore instance. + It is meant to be connected to a Chroma collection. Note: for the component to be part of a serializable pipeline, the __init__ parameters must be serializable, reason why we use a registry to configure the @@ -65,7 +64,6 @@ def __init__( :param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the `distance_function` parameter above. - :param embedding_function_params: additional parameters to pass to the embedding function. """ @@ -79,53 +77,61 @@ def __init__( # Store the params for marshalling self._collection_name = collection_name self._embedding_function = embedding_function + self._embedding_func = get_embedding_function(embedding_function, **embedding_function_params) self._embedding_function_params = embedding_function_params self._distance_function = distance_function + self._metadata = metadata + self._collection = None self._persist_path = persist_path self._host = host self._port = port - # Create the client instance - if persist_path and (host or port is not None): - error_message = ( - "You must specify `persist_path` for local persistent storage or, " - "alternatively, `host` and `port` for remote HTTP client connection. " - "You cannot specify both options." 
- ) - raise ValueError(error_message) - if host and port is not None: - # Remote connection via HTTP client - self._chroma_client = chromadb.HttpClient( - host=host, - port=port, - ) - elif persist_path is None: - # In-memory storage - self._chroma_client = chromadb.Client() - else: - # Local persistent storage - self._chroma_client = chromadb.PersistentClient(path=persist_path) + self._initialized = False - embedding_func = get_embedding_function(embedding_function, **embedding_function_params) + def _ensure_initialized(self): + if not self._initialized: + # Create the client instance + if self._persist_path and (self._host or self._port is not None): + error_message = ( + "You must specify `persist_path` for local persistent storage or, " + "alternatively, `host` and `port` for remote HTTP client connection. " + "You cannot specify both options." + ) + raise ValueError(error_message) + if self._host and self._port is not None: + # Remote connection via HTTP client + client = chromadb.HttpClient( + host=self._host, + port=self._port, + ) + elif self._persist_path is None: + # In-memory storage + client = chromadb.Client() + else: + # Local persistent storage + client = chromadb.PersistentClient(path=self._persist_path) - metadata = metadata or {} - if "hnsw:space" not in metadata: - metadata["hnsw:space"] = distance_function + self._metadata = self._metadata or {} + if "hnsw:space" not in self._metadata: + self._metadata["hnsw:space"] = self._distance_function - if collection_name in [c.name for c in self._chroma_client.list_collections()]: - self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func) + if self._collection_name in [c.name for c in client.list_collections()]: + self._collection = client.get_collection(self._collection_name, embedding_function=self._embedding_func) - if metadata != self._collection.metadata: - logger.warning( - "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + if self._metadata != self._collection.metadata: + logger.warning( + "Collection already exists. " + "The `distance_function` and `metadata` parameters will be ignored." + ) + else: + self._collection = client.create_collection( + name=self._collection_name, + metadata=self._metadata, + embedding_function=self._embedding_func, ) - else: - self._collection = self._chroma_client.create_collection( - name=collection_name, - metadata=metadata, - embedding_function=embedding_func, - ) + + self._initialized = True def count_documents(self) -> int: """ @@ -133,6 +139,8 @@ def count_documents(self) -> int: :returns: how many documents are present in the document store. """ + self._ensure_initialized() + assert self._collection is not None return self._collection.count() def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: @@ -197,6 +205,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :param filters: the filters to apply to the document list. :returns: a list of Documents that match the given filters. 
""" + self._ensure_initialized() + assert self._collection is not None + if filters: chroma_filter = _convert_filters(filters) kwargs: Dict[str, Any] = {"where": chroma_filter.where} @@ -227,6 +238,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D :returns: The number of documents written """ + self._ensure_initialized() + assert self._collection is not None + for doc in documents: if not isinstance(doc, Document): msg = "param 'documents' must contain a list of objects of type Document" @@ -280,8 +294,11 @@ def delete_documents(self, document_ids: List[str]) -> None: """ Deletes all documents with a matching document_ids from the document store. - :param document_ids: the object_ids to delete + :param document_ids: the document ids to delete """ + self._ensure_initialized() + assert self._collection is not None + self._collection.delete(ids=document_ids) def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]: @@ -292,6 +309,9 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. :returns: matching documents for each query. """ + self._ensure_initialized() + assert self._collection is not None + if filters is None: results = self._collection.query( query_texts=queries, @@ -323,6 +343,9 @@ def search_embeddings( :returns: a list of lists of documents that match the given filters. """ + self._ensure_initialized() + assert self._collection is not None + if filters is None: results = self._collection.query( query_embeddings=query_embeddings, diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index d33086945..41491dc4d 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -98,7 +98,8 @@ def test_invalid_initialization_both_host_and_persist_path(self): Test that providing both host and persist_path raises an error. """ with pytest.raises(ValueError): - ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + store._ensure_initialized() def test_delete_empty(self, document_store: ChromaDocumentStore): """ @@ -207,6 +208,7 @@ def test_same_collection_name_reinitialization(self): @pytest.mark.integration def test_distance_metric_initialization(self): store = ChromaDocumentStore("test_2", distance_function="cosine") + store._ensure_initialized() assert store._collection.metadata["hnsw:space"] == "cosine" with pytest.raises(ValueError): @@ -215,9 +217,11 @@ def test_distance_metric_initialization(self): @pytest.mark.integration def test_distance_metric_reinitialization(self, caplog): store = ChromaDocumentStore("test_4", distance_function="cosine") + store._ensure_initialized() with caplog.at_level(logging.WARNING): new_store = ChromaDocumentStore("test_4", distance_function="ip") + new_store._ensure_initialized() assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." 
@@ -238,6 +242,8 @@ def test_metadata_initialization(self, caplog): "hnsw:M": 103, }, ) + store._ensure_initialized() + assert store._collection.metadata["hnsw:space"] == "ip" assert store._collection.metadata["hnsw:search_ef"] == 101 assert store._collection.metadata["hnsw:construction_ef"] == 102 @@ -254,6 +260,8 @@ def test_metadata_initialization(self, caplog): }, ) + new_store._ensure_initialized() + assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." in caplog.text From e8b71727165b248d34a0c05fcaa48c71dd92dca6 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 30 Sep 2024 09:17:04 +0000 Subject: [PATCH 015/229] Update the changelog --- integrations/chroma/CHANGELOG.md | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/integrations/chroma/CHANGELOG.md b/integrations/chroma/CHANGELOG.md index f6a23d84a..89ab0fea5 100644 --- a/integrations/chroma/CHANGELOG.md +++ b/integrations/chroma/CHANGELOG.md @@ -1,5 +1,28 @@ # Changelog +## [integrations/chroma-v0.22.0] - 2024-09-30 + +### 🚀 Features + +- Chroma - allow remote HTTP connection (#1094) +- Chroma - defer the DB connection (#1107) + +### 🐛 Bug Fixes + +- Fix chroma linting; rm numpy (#1063) + +Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> +- Filters in chroma integration (#1072) + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) + +### ⚙️ Miscellaneous Tasks + +- Chroma - ruff update, don't ruff tests (#983) +- Update ruff linting scripts and settings (#1105) + ## [integrations/chroma-v0.21.1] - 2024-07-17 ### 🐛 Bug Fixes @@ -76,8 +99,6 @@ This PR will also push the docs to Readme - Fix project urls (#96) - - ### 🚜 Refactor - Use `hatch_vcs` to manage integrations versioning (#103) @@ -88,13 +109,10 @@ This PR will also push the docs to Readme - Fix import and increase version (#77) - - ## [integrations/chroma-v0.8.0] - 2023-12-04 ### 🐛 Bug Fixes - Fix license headers - From 5c2e61b7c8cfecc6c5acf0442738de06a24a2648 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 30 Sep 2024 12:58:53 +0200 Subject: [PATCH 016/229] chroma: empty filters should behave as no filters (#1117) --- .../document_stores/chroma/document_store.py | 4 ++-- integrations/chroma/tests/test_document_store.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 990aa4c34..6a83937a4 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -312,7 +312,7 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any self._ensure_initialized() assert self._collection is not None - if filters is None: + if not filters: results = self._collection.query( query_texts=queries, n_results=top_k, @@ -346,7 +346,7 @@ def search_embeddings( self._ensure_initialized() assert self._collection is not None - if filters is None: + if not filters: results = self._collection.query( query_embeddings=query_embeddings, n_results=top_k, diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 41491dc4d..f386b44ba 100644 --- a/integrations/chroma/tests/test_document_store.py +++ 
b/integrations/chroma/tests/test_document_store.py @@ -137,6 +137,10 @@ def test_search(self): assert isinstance(doc.embedding, list) assert all(isinstance(el, float) for el in doc.embedding) + # check that empty filters behave as no filters + result_empty_filters = document_store.search(["Third"], filters={}, top_k=1) + assert result == result_empty_filters + def test_write_documents_unsupported_meta_values(self, document_store: ChromaDocumentStore): """ Unsupported meta values should be removed from the documents before writing them to the database From 60dbc8c36dfd84f1f8bd4ed3280ce5d4d40584c1 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 30 Sep 2024 11:00:17 +0000 Subject: [PATCH 017/229] Update the changelog --- integrations/chroma/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/chroma/CHANGELOG.md b/integrations/chroma/CHANGELOG.md index 89ab0fea5..c129d00ae 100644 --- a/integrations/chroma/CHANGELOG.md +++ b/integrations/chroma/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/chroma-v0.22.1] - 2024-09-30 + +### Chroma + +- Empty filters should behave as no filters (#1117) + ## [integrations/chroma-v0.22.0] - 2024-09-30 ### 🚀 Features From 97094f7ecf9755d4128c6c7d09295ec8b91fa78d Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 30 Sep 2024 18:31:04 +0200 Subject: [PATCH 018/229] chore: unpin `llama-cpp-python` (#1115) * unpin llamacpp * progress * fixes * remove pin * rm generation_kwargs * rm print * adjust --- integrations/llama_cpp/pyproject.toml | 2 +- integrations/llama_cpp/tests/test_chat_generator.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index 673df575a..acf42d958 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "llama-cpp-python>=0.2.87,<0.3.0"] +dependencies = ["haystack-ai", "llama-cpp-python>=0.2.87"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama_cpp#readme" diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py index 1d4c9cf82..802fe9128 100644 --- a/integrations/llama_cpp/tests/test_chat_generator.py +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -342,7 +342,7 @@ def generator(self, model_path, capsys): hf_tokenizer_path = "meetkai/functionary-small-v2.4-GGUF" generator = LlamaCppChatGenerator( model=model_path, - n_ctx=8192, + n_ctx=512, n_batch=512, model_kwargs={ "chat_format": "functionary-v2", @@ -399,7 +399,6 @@ def test_function_call_and_execute(self, generator): "type": "string", "description": "The city and state, e.g. 
San Francisco, CA", }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, @@ -407,7 +406,8 @@ def test_function_call_and_execute(self, generator): } ] - response = generator.run(messages=messages, generation_kwargs={"tools": tools}) + tool_choice = {"type": "function", "function": {"name": "get_current_temperature"}} + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) available_functions = { "get_current_temperature": self.get_current_temperature, From 3d6bb776bb500de0e51c839633082353eebbc784 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Tue, 1 Oct 2024 13:28:50 +0200 Subject: [PATCH 019/229] feat: Raise error when attempting to embed empty documents/strings with Nvidia embedders (#1118) * feat: Raise error when attempting to embed empty documents/strings with Nvidia embedders * feat: Improve Nim backend error handling --- .../embedders/nvidia/document_embedder.py | 5 ++ .../embedders/nvidia/text_embedder.py | 3 + .../utils/nvidia/nim_backend.py | 84 +++++++++++-------- .../nvidia/tests/test_document_embedder.py | 11 +++ .../nvidia/tests/test_text_embedder.py | 11 +++ 5 files changed, 78 insertions(+), 36 deletions(-) diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index 3e911e4f4..d746a75f4 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -231,6 +231,11 @@ def run(self, documents: List[Document]): ) raise TypeError(msg) + for doc in documents: + if not doc.content: + msg = f"Document '{doc.id}' has no content to embed." + raise ValueError(msg) + texts_to_embed = self._prepare_texts_to_embed(documents) embeddings, metadata = self._embed_batch(texts_to_embed, self.batch_size) for doc, emb in zip(documents, embeddings): diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 0387c32b7..22bed8197 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -175,6 +175,9 @@ def run(self, text: str): "In case you want to embed a list of Documents, please use the NvidiaDocumentEmbedder." ) raise TypeError(msg) + elif not text: + msg = "Cannot embed an empty string." 
+ raise ValueError(msg) assert self.backend is not None text_to_embed = self.prefix + text + self.suffix diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index 0d1f57e5c..cbb6b7c3f 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -50,16 +50,20 @@ def __init__( def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: url = f"{self.api_url}/embeddings" - res = self.session.post( - url, - json={ - "model": self.model, - "input": texts, - **self.model_kwargs, - }, - timeout=REQUEST_TIMEOUT, - ) - res.raise_for_status() + try: + res = self.session.post( + url, + json={ + "model": self.model, + "input": texts, + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + except requests.HTTPError as e: + msg = f"Failed to query embedding endpoint: Error - {e.response.text}" + raise ValueError(msg) from e data = res.json() # Sort the embeddings by index, we don't know whether they're out of order or not @@ -73,21 +77,25 @@ def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: # This is the same for local containers and the cloud API. url = f"{self.api_url}/chat/completions" - res = self.session.post( - url, - json={ - "model": self.model, - "messages": [ - { - "role": "user", - "content": prompt, - }, - ], - **self.model_kwargs, - }, - timeout=REQUEST_TIMEOUT, - ) - res.raise_for_status() + try: + res = self.session.post( + url, + json={ + "model": self.model, + "messages": [ + { + "role": "user", + "content": prompt, + }, + ], + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + except requests.HTTPError as e: + msg = f"Failed to query chat completion endpoint: Error - {e.response.text}" + raise ValueError(msg) from e completions = res.json() choices = completions["choices"] @@ -139,17 +147,21 @@ def rank( ) -> List[Dict[str, Any]]: url = endpoint or f"{self.api_url}/ranking" - res = self.session.post( - url, - json={ - "model": self.model, - "query": {"text": query}, - "passages": [{"text": doc.content} for doc in documents], - **self.model_kwargs, - }, - timeout=REQUEST_TIMEOUT, - ) - res.raise_for_status() + try: + res = self.session.post( + url, + json={ + "model": self.model, + "query": {"text": query}, + "passages": [{"text": doc.content} for doc in documents], + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + except requests.HTTPError as e: + msg = f"Failed to rank endpoint: Error - {e.response.text}" + raise ValueError(msg) from e data = res.json() assert "rankings" in data, f"Expected 'rankings' in response, got {data}" diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index bef0f996e..db69053e7 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -326,6 +326,17 @@ def test_run_wrong_input_format(self): with pytest.raises(TypeError, match="NvidiaDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=list_integers_input) + def test_run_empty_document(self): + model = "playground_nvolveqa_40k" + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaDocumentEmbedder(model, api_key=api_key) + + embedder.warm_up() + embedder.backend = 
MockBackend(model=model, api_key=api_key) + + with pytest.raises(ValueError, match="no content to embed"): + embedder.run(documents=[Document(content="")]) + def test_run_on_empty_list(self): model = "playground_nvolveqa_40k" api_key = Secret.from_token("fake-api-key") diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 7c8428cc2..8690de6b1 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -147,6 +147,17 @@ def test_run_wrong_input_format(self): with pytest.raises(TypeError, match="NvidiaTextEmbedder expects a string as an input"): embedder.run(text=list_integers_input) + def test_run_empty_string(self): + model = "playground_nvolveqa_40k" + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaTextEmbedder(model, api_key=api_key) + + embedder.warm_up() + embedder.backend = MockBackend(model=model, api_key=api_key) + + with pytest.raises(ValueError, match="empty string"): + embedder.run(text="") + @pytest.mark.skipif( not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " From 3511eb39176c40186f628721fcd6407e7bcf5f76 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 1 Oct 2024 16:45:32 +0200 Subject: [PATCH 020/229] Langfuse: add invocation_context to identify traces (#1089) * Langfuse: add invocation_context to identify traces * Lint * Add unit test * Log debug invocation context * Lint + format --- integrations/langfuse/example/chat.py | 12 +++++++++++- .../connectors/langfuse/langfuse_connector.py | 16 ++++++++++++++-- integrations/langfuse/tests/test_tracing.py | 9 ++++++++- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/integrations/langfuse/example/chat.py b/integrations/langfuse/example/chat.py index 0d9c42787..2308ed1f4 100644 --- a/integrations/langfuse/example/chat.py +++ b/integrations/langfuse/example/chat.py @@ -49,6 +49,16 @@ ChatMessage.from_user("Tell me about {{location}}"), ] - response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) + response = pipe.run( + data={ + "prompt_builder": { + "template_variables": {"location": "Berlin"}, + "template": messages, + }, + "tracer": { + "invocation_context": {"some_key": "some_value"}, + }, + } + ) print(response["llm"]["replies"][0]) print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py index 29f58d722..ff0a7c6ed 100644 --- a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py +++ b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py @@ -1,8 +1,12 @@ -from haystack import component, tracing +from typing import Any, Dict, Optional + +from haystack import component, logging, tracing from haystack_integrations.tracing.langfuse import LangfuseTracer from langfuse import Langfuse +logger = logging.getLogger(__name__) + @component class LangfuseConnector: @@ -105,12 +109,20 @@ def __init__(self, name: str, public: bool = False): tracing.enable_tracing(self.tracer) @component.output_types(name=str, trace_url=str) - def run(self): + def run(self, invocation_context: 
Optional[Dict[str, Any]] = None): """ Runs the LangfuseConnector component. + :param invocation_context: A dictionary with additional context for the invocation. This parameter + is useful when users want to mark this particular invocation with additional information, e.g. + a run id from their own execution framework, user id, etc. These key-value pairs are then visible + in the Langfuse traces. :returns: A dictionary with the following keys: - `name`: The name of the tracing component. - `trace_url`: The URL to the tracing data. """ + logger.debug( + "Langfuse tracer invoked with the following context: '{invocation_context}'", + invocation_context=invocation_context, + ) return {"name": self.name, "trace_url": self.tracer.get_trace_url()} diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 936064e0a..657b6eae1 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -43,7 +43,12 @@ def test_tracing_integration(llm_class, env_var, expected_trace): ChatMessage.from_user("Tell me about {{location}}"), ] - response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) + response = pipe.run( + data={ + "prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}, + "tracer": {"invocation_context": {"user_id": "user_42"}}, + } + ) assert "Berlin" in response["llm"]["replies"][0].content assert response["tracer"]["trace_url"] @@ -65,5 +70,7 @@ def test_tracing_integration(llm_class, env_var, expected_trace): assert expected_trace in str(response.content) # check if the trace contains the expected generation span assert "GENERATION" in str(response.content) + # check if the trace contains the expected user_id + assert "user_42" in str(response.content) except requests.exceptions.RequestException as e: pytest.fail(f"Failed to retrieve data from Langfuse API: {e}") From dbd844cd138e2fd878a6c0d8bd64c4cc01faffdf Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 1 Oct 2024 14:47:21 +0000 Subject: [PATCH 021/229] Update the changelog --- integrations/langfuse/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index ccd68ded3..29be7f838 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/langfuse-v0.5.0] - 2024-10-01 + +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) + +### Langfuse + +- Add invocation_context to identify traces (#1089) + ## [integrations/langfuse-v0.4.0] - 2024-09-17 ### 🚀 Features From cc5defbb6fd1fd4f3e44ade6d6c2d7c617e67ad1 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 4 Oct 2024 11:04:55 +0200 Subject: [PATCH 022/229] pin version (#1124) --- integrations/google_vertex/pyproject.toml | 2 +- .../google_vertex/tests/chat/test_gemini.py | 18 ++++++++++++------ .../google_vertex/tests/test_gemini.py | 18 ++++++++++++------ 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index a0cefbcd4..51bc4ffd7 100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = 
["haystack-ai", "google-cloud-aiplatform>=1.38", "pyarrow>3"] +dependencies = ["haystack-ai", "google-cloud-aiplatform>=1.61", "pyarrow>3"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_vertex#readme" diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 87b43d66b..6b1308dab 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -22,11 +22,11 @@ name="get_current_weather", description="Get the current weather in a given location", parameters={ - "type_": "OBJECT", + "type": "object", "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, + "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": { - "type_": "STRING", + "type": "string", "enum": [ "celsius", "fahrenheit", @@ -238,13 +238,19 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": { - "type_": "OBJECT", + "type": "object", "properties": { "location": { - "type_": "STRING", + "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit", + ], + }, }, "required": ["location"], }, diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 1543f3ccf..9ec3529d7 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -18,11 +18,11 @@ name="get_current_weather", description="Get the current weather in a given location", parameters={ - "type_": "OBJECT", + "type": "object", "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, + "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": { - "type_": "STRING", + "type": "string", "enum": [ "celsius", "fahrenheit", @@ -226,13 +226,19 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): { "name": "get_current_weather", "parameters": { - "type_": "OBJECT", + "type": "object", "properties": { - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, "location": { - "type_": "STRING", + "type": "string", "description": "The city and state, e.g. 
San Francisco, CA", }, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit", + ], + }, }, "required": ["location"], }, From cf52ce94c9a1ae7f33bbde14c8639e2058262707 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 4 Oct 2024 09:41:18 +0000 Subject: [PATCH 023/229] Update the changelog --- integrations/google_vertex/CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index 17a730b60..23ab51b3d 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,11 +1,13 @@ # Changelog -## [unreleased] +## [integrations/google_vertex-v2.1.0] - 2024-10-04 ### 🚀 Features - Enable streaming for VertexAIGeminiChatGenerator (#1014) - Add tests for VertexAIGeminiGenerator and enable streaming (#1012) +- Add chatrole tests and meta for GeminiChatGenerators (#1090) +- Add custom params to VertexAIGeminiGenerator and VertexAIGeminiChatGenerator (#1100) ### 🐛 Bug Fixes @@ -21,6 +23,7 @@ - Retry tests to reduce flakyness (#836) - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) ## [integrations/google_vertex-v1.1.0] - 2024-03-28 From 62483b78232a75e729a5325cf774673ca0906c1e Mon Sep 17 00:00:00 2001 From: 1greentangerine <158560711+1greentangerine@users.noreply.github.com> Date: Fri, 4 Oct 2024 12:19:06 +0200 Subject: [PATCH 024/229] modify regex to allow cross-region inference in bedrock (#1120) * modify regex to allow cross-region inference in bedrock (only possible for claude models) * add tests for multi-region inference with claude models --- .../components/generators/amazon_bedrock/chat/chat_generator.py | 2 +- .../components/generators/amazon_bedrock/generator.py | 2 +- integrations/amazon_bedrock/tests/test_chat_generator.py | 2 ++ integrations/amazon_bedrock/tests/test_generator.py | 2 ++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 988452a97..e1732646a 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -58,7 +58,7 @@ class AmazonBedrockChatGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"anthropic.claude.*": AnthropicClaudeChatAdapter, + r"(.+\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, r"meta.llama2.*": MetaLlama2ChatAdapter, r"mistral.*": MistralChatAdapter, } diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 6ef0a4765..1edde3526 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -69,7 +69,7 @@ class AmazonBedrockGenerator: r"ai21.j2.*": AI21LabsJurassic2Adapter, r"cohere.command-[^r].*": CohereCommandAdapter, r"cohere.command-r.*": CohereCommandRAdapter, - r"anthropic.claude.*": AnthropicClaudeAdapter, + 
r"(.+\.)?anthropic.claude.*": AnthropicClaudeAdapter, r"meta.llama.*": MetaLlamaAdapter, r"mistral.*": MistralAdapter, } diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index a455d2c93..49abc0979 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -254,6 +254,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): [ ("anthropic.claude-v1", AnthropicClaudeChatAdapter), ("anthropic.claude-v2", AnthropicClaudeChatAdapter), + ("eu.anthropic.claude-v1", AnthropicClaudeChatAdapter), # cross-region inference + ("us.anthropic.claude-v2", AnthropicClaudeChatAdapter), # cross-region inference ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter), ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index f0233888c..61ae9d6b4 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -231,6 +231,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): [ ("anthropic.claude-v1", AnthropicClaudeAdapter), ("anthropic.claude-v2", AnthropicClaudeAdapter), + ("eu.anthropic.claude-v1", AnthropicClaudeAdapter), # cross-region inference + ("us.anthropic.claude-v2", AnthropicClaudeAdapter), # cross-region inference ("anthropic.claude-instant-v1", AnthropicClaudeAdapter), ("anthropic.claude-super-v5", AnthropicClaudeAdapter), # artificial ("cohere.command-text-v14", CohereCommandAdapter), From ddb0c63abc3c0dbdc5f7a4dcad05dc72992a79db Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 4 Oct 2024 10:20:34 +0000 Subject: [PATCH 025/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 417c661fe..e7b0db667 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,15 +1,20 @@ # Changelog -## [unreleased] +## [integrations/amazon_bedrock-v1.0.3] - 2024-10-04 ### 🐛 Bug Fixes - *(Bedrock)* Allow tools kwargs for AWS Bedrock Claude model (#976) +- Chat roles for model responses in chat generators (#1030) ### 🚜 Refactor - Remove usage of deprecated `ChatMessage.to_openai_format` (#1007) +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) + ## [integrations/amazon_bedrock-v1.0.1] - 2024-08-19 ### 🚀 Features From 42cc2ae8557a68ae61f22117dbacea46fc39ff67 Mon Sep 17 00:00:00 2001 From: Alper Date: Mon, 7 Oct 2024 10:55:49 +0200 Subject: [PATCH 026/229] feat: introduce `model_kwargs` in Sparse Embedders (can be used for BM25 parameters) (#1126) * add support for bm25 in FastEmbed integration * agnostic support for model config params * Apply suggestions from code review Co-authored-by: Stefano Fiorucci * add future readability --------- Co-authored-by: Stefano Fiorucci --- .../embedding_backend/fastembed_backend.py | 20 +++++-- .../fastembed_sparse_document_embedder.py | 5 ++ .../fastembed_sparse_text_embedder.py | 5 ++ .../fastembed/tests/test_fastembed_backend.py | 26 +++++++++ ...test_fastembed_sparse_document_embedder.py | 54 +++++++++++++++++- .../test_fastembed_sparse_text_embedder.py | 56 
++++++++++++++++++- 6 files changed, 160 insertions(+), 6 deletions(-) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index 66f797549..3a68abcfb 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional from haystack.dataclasses.sparse_embedding import SparseEmbedding from tqdm import tqdm @@ -73,14 +73,19 @@ def get_embedding_backend( cache_dir: Optional[str] = None, threads: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): - embedding_backend_id = f"{model_name}{cache_dir}{threads}" + embedding_backend_id = f"{model_name}{cache_dir}{threads}{local_files_only}{model_kwargs}" if embedding_backend_id in _FastembedSparseEmbeddingBackendFactory._instances: return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] embedding_backend = _FastembedSparseEmbeddingBackend( - model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + local_files_only=local_files_only, + model_kwargs=model_kwargs, ) _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -97,9 +102,16 @@ def __init__( cache_dir: Optional[str] = None, threads: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): + model_kwargs = model_kwargs or {} + self.model = SparseTextEmbedding( - model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + local_files_only=local_files_only, + **model_kwargs, ) def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index 4b72389fa..f79f08c90 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -62,6 +62,7 @@ def __init__( local_files_only: bool = False, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Create an FastembedDocumentEmbedder component. @@ -81,6 +82,7 @@ def __init__( :param local_files_only: If `True`, only use the model files in the `cache_dir`. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. + :param model_kwargs: Dictionary containing model parameters such as `k`, `b`, `avg_len`, `language`. 
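# A small sketch of the new model_kwargs option on the sparse document embedder, assuming the
# fastembed-haystack package is installed and the Qdrant/bm25 model can be downloaded; the
# BM25 parameter values below are illustrative.
from haystack import Document

from haystack_integrations.components.embedders.fastembed import FastembedSparseDocumentEmbedder

doc_embedder = FastembedSparseDocumentEmbedder(
    model="Qdrant/bm25",
    model_kwargs={"k": 1.2, "b": 0.75, "avg_len": 300.0},  # forwarded to fastembed's SparseTextEmbedding
)
doc_embedder.warm_up()
result = doc_embedder.run(documents=[Document(content="BM25 parameters are forwarded to fastembed.")])
print(result["documents"][0].sparse_embedding.to_dict()["indices"][:5])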
""" self.model_name = model @@ -92,6 +94,7 @@ def __init__( self.local_files_only = local_files_only self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator + self.model_kwargs = model_kwargs def to_dict(self) -> Dict[str, Any]: """ @@ -110,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]: local_files_only=self.local_files_only, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + model_kwargs=self.model_kwargs, ) def warm_up(self): @@ -122,6 +126,7 @@ def warm_up(self): cache_dir=self.cache_dir, threads=self.threads, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py index 67348b2bd..2ebab35b4 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py @@ -35,6 +35,7 @@ def __init__( progress_bar: bool = True, parallel: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Create a FastembedSparseTextEmbedder component. @@ -50,6 +51,7 @@ def __init__( If 0, use all available cores. If None, don't use data-parallel processing, use default onnxruntime threading instead. :param local_files_only: If `True`, only use the model files in the `cache_dir`. + :param model_kwargs: Dictionary containing model parameters such as `k`, `b`, `avg_len`, `language`. 
""" self.model_name = model @@ -58,6 +60,7 @@ def __init__( self.progress_bar = progress_bar self.parallel = parallel self.local_files_only = local_files_only + self.model_kwargs = model_kwargs def to_dict(self) -> Dict[str, Any]: """ @@ -74,6 +77,7 @@ def to_dict(self) -> Dict[str, Any]: progress_bar=self.progress_bar, parallel=self.parallel, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) def warm_up(self): @@ -86,6 +90,7 @@ def warm_up(self): cache_dir=self.cache_dir, threads=self.threads, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) @component.output_types(sparse_embedding=SparseEmbedding) diff --git a/integrations/fastembed/tests/test_fastembed_backend.py b/integrations/fastembed/tests/test_fastembed_backend.py index 631d9f1e0..994a6f883 100644 --- a/integrations/fastembed/tests/test_fastembed_backend.py +++ b/integrations/fastembed/tests/test_fastembed_backend.py @@ -2,6 +2,7 @@ from haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend import ( _FastembedEmbeddingBackendFactory, + _FastembedSparseEmbeddingBackendFactory, ) @@ -44,3 +45,28 @@ def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 embedding_backend.model.embed.assert_called_once_with(data) # restore the factory stateTrue _FastembedEmbeddingBackendFactory._instances = {} + + +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.SparseTextEmbedding") +def test_model_kwargs_initialization(mock_instructor): + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 40, + } + + # Invoke the backend factory with the BM25 configuration + _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( + model_name="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + # Check if SparseTextEmbedding was called with the correct arguments + mock_instructor.assert_called_once_with( + model_name="Qdrant/bm25", cache_dir=None, threads=None, local_files_only=False, **bm25_config + ) + + # Restore factory state after the test + _FastembedSparseEmbeddingBackendFactory._instances = {} diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py index d3f2023b8..90e94908d 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py @@ -69,6 +69,7 @@ def test_to_dict(self): "local_files_only": False, "embedding_separator": "\n", "meta_fields_to_embed": [], + "model_kwargs": None, }, } @@ -100,6 +101,7 @@ def test_to_dict_with_custom_init_parameters(self): "local_files_only": True, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", + "model_kwargs": None, }, } @@ -174,7 +176,11 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False + model_name="prithvida/Splade_PP_en_v1", + cache_dir=None, + threads=None, + local_files_only=False, + model_kwargs=None, ) @patch( @@ -275,6 +281,52 @@ def test_embed_metadata(self): parallel=None, ) + def test_init_with_model_kwargs_parameters(self): + """ + Test initialization of FastembedSparseDocumentEmbedder with model_kwargs parameters. 
+ """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 50, + } + + embedder = FastembedSparseDocumentEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + assert embedder.model_kwargs == bm25_config + + @pytest.mark.integration + def test_run_with_model_kwargs(self): + """ + Integration test to check the embedding with model_kwargs parameters. + """ + bm42_config = { + "alpha": 0.2, + } + + embedder = FastembedSparseDocumentEmbedder( + model="Qdrant/bm42-all-minilm-l6-v2-attentions", + model_kwargs=bm42_config, + ) + embedder.warm_up() + + doc = Document(content="Example content using BM42") + + result = embedder.run(documents=[doc]) + embedding = result["documents"][0].sparse_embedding + embedding_dict = embedding.to_dict() + + assert isinstance(embedding, SparseEmbedding) + assert isinstance(embedding_dict["indices"], list) + assert isinstance(embedding_dict["values"], list) + assert isinstance(embedding_dict["indices"][0], int) + assert isinstance(embedding_dict["values"][0], float) + @pytest.mark.integration def test_run(self): embedder = FastembedSparseDocumentEmbedder( diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py index 7e9197493..4f438fd15 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py @@ -54,6 +54,7 @@ def test_to_dict(self): "progress_bar": True, "parallel": None, "local_files_only": False, + "model_kwargs": None, }, } @@ -79,6 +80,7 @@ def test_to_dict_with_custom_init_parameters(self): "progress_bar": False, "parallel": 1, "local_files_only": True, + "model_kwargs": None, }, } @@ -135,7 +137,11 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False + model_name="prithvida/Splade_PP_en_v1", + cache_dir=None, + threads=None, + local_files_only=False, + model_kwargs=None, ) @patch( @@ -195,6 +201,54 @@ def test_run_wrong_incorrect_format(self): with pytest.raises(TypeError, match="FastembedSparseTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) + def test_init_with_model_kwargs_parameters(self): + """ + Test initialization of FastembedSparseTextEmbedder with model_kwargs parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 50, + } + + embedder = FastembedSparseTextEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + assert embedder.model_kwargs == bm25_config + + @pytest.mark.integration + def test_run_with_model_kwargs(self): + """ + Integration test to check the embedding with model_kwargs parameters. 
+ """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 256.0, + } + + embedder = FastembedSparseTextEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + embedder.warm_up() + + text = "Example content using BM25" + + result = embedder.run(text=text) + embedding = result["sparse_embedding"] + embedding_dict = embedding.to_dict() + + assert isinstance(embedding, SparseEmbedding) + assert isinstance(embedding_dict["indices"], list) + assert isinstance(embedding_dict["values"], list) + assert isinstance(embedding_dict["indices"][0], int) + assert isinstance(embedding_dict["values"][0], float) + @pytest.mark.integration def test_run(self): embedder = FastembedSparseTextEmbedder( From 9438634f2082bc8c7ebd518f319512ed4fbe9882 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 7 Oct 2024 08:58:06 +0000 Subject: [PATCH 027/229] Update the changelog --- integrations/fastembed/CHANGELOG.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/integrations/fastembed/CHANGELOG.md b/integrations/fastembed/CHANGELOG.md index 9ae3da929..b5c194d8b 100644 --- a/integrations/fastembed/CHANGELOG.md +++ b/integrations/fastembed/CHANGELOG.md @@ -1,11 +1,20 @@ # Changelog -## [unreleased] +## [integrations/fastembed-v1.3.0] - 2024-10-07 + +### 🚀 Features + +- Introduce `model_kwargs` in Sparse Embedders (can be used for BM25 parameters) (#1126) + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) ### ⚙️ Miscellaneous Tasks - Retry tests to reduce flakyness (#836) - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) ### Fix From 9f1fb944a7225de9d09cf09650199a609eccd664 Mon Sep 17 00:00:00 2001 From: emso-c <32599085+emso-c@users.noreply.github.com> Date: Fri, 11 Oct 2024 17:51:17 +0300 Subject: [PATCH 028/229] feat: add `keep_alive` parameter to Ollama Generators (#1131) * feat: add keep_alive parameter to Ollama integrations * style: run linter * fix: serialize keep_alive parameters * test: include keep_alive parameter in tests * docs: add keep_alive usage to the docstring * style: I keep forgetting to lint * style: update docs * small fixes to docstrings --------- Co-authored-by: anakin87 --- .../generators/ollama/chat/chat_generator.py | 17 +++++++++++++++-- .../components/generators/ollama/generator.py | 17 +++++++++++++++-- .../ollama/tests/test_chat_generator.py | 7 +++++++ integrations/ollama/tests/test_generator.py | 8 ++++++++ 4 files changed, 45 insertions(+), 4 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 9502a187e..558fd593e 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk @@ -38,6 +38,7 @@ def __init__( url: str = "http://localhost:11434", generation_kwargs: Optional[Dict[str, Any]] = None, timeout: int = 120, + keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -54,12 +55,21 @@ def __init__( 
:param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param keep_alive: + The option that controls how long the model will stay loaded into memory following the request. + If not set, it will use the default value from the Ollama (5 minutes). + The value can be set to: + - a duration string (such as "10m" or "24h") + - a number in seconds (such as 3600) + - any negative number which will keep the model loaded in memory (e.g. -1 or "-1m") + - '0' which will unload the model immediately after generating a response. """ self.timeout = timeout self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model + self.keep_alive = keep_alive self.streaming_callback = streaming_callback self._client = Client(host=self.url, timeout=self.timeout) @@ -76,6 +86,7 @@ def to_dict(self) -> Dict[str, Any]: self, model=self.model, url=self.url, + keep_alive=self.keep_alive, generation_kwargs=self.generation_kwargs, timeout=self.timeout, streaming_callback=callback_name, @@ -165,7 +176,9 @@ def run( stream = self.streaming_callback is not None messages = [self._message_to_dict(message) for message in messages] - response = self._client.chat(model=self.model, messages=messages, stream=stream, options=generation_kwargs) + response = self._client.chat( + model=self.model, messages=messages, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs + ) if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index d92932c3e..058948e8a 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk @@ -36,6 +36,7 @@ def __init__( template: Optional[str] = None, raw: bool = False, timeout: int = 120, + keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -59,6 +60,14 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param keep_alive: + The option that controls how long the model will stay loaded into memory following the request. + If not set, it will use the default value from the Ollama (5 minutes). + The value can be set to: + - a duration string (such as "10m" or "24h") + - a number in seconds (such as 3600) + - any negative number which will keep the model loaded in memory (e.g. -1 or "-1m") + - '0' which will unload the model immediately after generating a response. 
""" self.timeout = timeout self.raw = raw @@ -66,6 +75,7 @@ def __init__( self.system_prompt = system_prompt self.model = model self.url = url + self.keep_alive = keep_alive self.generation_kwargs = generation_kwargs or {} self.streaming_callback = streaming_callback @@ -87,6 +97,7 @@ def to_dict(self) -> Dict[str, Any]: system_prompt=self.system_prompt, model=self.model, url=self.url, + keep_alive=self.keep_alive, generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, ) @@ -172,7 +183,9 @@ def run( stream = self.streaming_callback is not None - response = self._client.generate(model=self.model, prompt=prompt, stream=stream, options=generation_kwargs) + response = self._client.generate( + model=self.model, prompt=prompt, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs + ) if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index a46758df3..5ac9289aa 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -26,12 +26,14 @@ def test_init_default(self): assert component.url == "http://localhost:11434" assert component.generation_kwargs == {} assert component.timeout == 120 + assert component.keep_alive is None def test_init(self): component = OllamaChatGenerator( model="llama2", url="http://my-custom-endpoint:11434", generation_kwargs={"temperature": 0.5}, + keep_alive="10m", timeout=5, ) @@ -39,6 +41,7 @@ def test_init(self): assert component.url == "http://my-custom-endpoint:11434" assert component.generation_kwargs == {"temperature": 0.5} assert component.timeout == 5 + assert component.keep_alive == "10m" def test_to_dict(self): component = OllamaChatGenerator( @@ -46,6 +49,7 @@ def test_to_dict(self): streaming_callback=print_streaming_chunk, url="custom_url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + keep_alive="5m", ) data = component.to_dict() assert data == { @@ -53,6 +57,7 @@ def test_to_dict(self): "init_parameters": { "timeout": 120, "model": "llama2", + "keep_alive": "5m", "url": "custom_url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -66,6 +71,7 @@ def test_from_dict(self): "timeout": 120, "model": "llama2", "url": "custom_url", + "keep_alive": "5m", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, @@ -75,6 +81,7 @@ def test_from_dict(self): assert component.streaming_callback is print_streaming_chunk assert component.url == "custom_url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.keep_alive == "5m" def test_build_message_from_ollama_response(self): model = "some_model" diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index c4c6906db..b02370234 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -45,6 +45,7 @@ def test_init_default(self): assert component.template is None assert component.raw is False assert component.timeout == 120 + assert component.keep_alive is None assert component.streaming_callback is None def test_init(self): @@ -57,6 +58,7 @@ def callback(x: StreamingChunk): 
generation_kwargs={"temperature": 0.5}, system_prompt="You are Luigi from Super Mario Bros.", timeout=5, + keep_alive="10m", streaming_callback=callback, ) assert component.model == "llama2" @@ -66,6 +68,7 @@ def callback(x: StreamingChunk): assert component.template is None assert component.raw is False assert component.timeout == 5 + assert component.keep_alive == "10m" assert component.streaming_callback == callback component = OllamaGenerator() @@ -80,6 +83,7 @@ def callback(x: StreamingChunk): "model": "orca-mini", "url": "http://localhost:11434", "streaming_callback": None, + "keep_alive": None, "generation_kwargs": {}, }, } @@ -89,6 +93,7 @@ def test_to_dict_with_parameters(self): model="llama2", streaming_callback=print_streaming_chunk, url="going_to_51_pegasi_b_for_weekend", + keep_alive="10m", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) data = component.to_dict() @@ -100,6 +105,7 @@ def test_to_dict_with_parameters(self): "template": None, "system_prompt": None, "model": "llama2", + "keep_alive": "10m", "url": "going_to_51_pegasi_b_for_weekend", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -115,6 +121,7 @@ def test_from_dict(self): "template": None, "system_prompt": None, "model": "llama2", + "keep_alive": "5m", "url": "going_to_51_pegasi_b_for_weekend", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -125,6 +132,7 @@ def test_from_dict(self): assert component.streaming_callback is print_streaming_chunk assert component.url == "going_to_51_pegasi_b_for_weekend" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.keep_alive == "5m" @pytest.mark.integration def test_ollama_generator_run_streaming(self): From 518cf27a8e765306bef4f6e0c7e1086169a11d4b Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 11 Oct 2024 14:58:03 +0000 Subject: [PATCH 029/229] Update the changelog --- integrations/ollama/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 5da725e87..55c6aa7b7 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/ollama-v1.1.0] - 2024-10-11 + +### 🚀 Features + +- Add `keep_alive` parameter to Ollama Generators (#1131) + +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) + ## [integrations/ollama-v1.0.1] - 2024-09-26 ### 🐛 Bug Fixes From a223e6f0705ec0d788205ae6c996919d5d0258d1 Mon Sep 17 00:00:00 2001 From: Kane Norman <51185594+kanenorman@users.noreply.github.com> Date: Tue, 15 Oct 2024 04:02:44 -0500 Subject: [PATCH 030/229] docs: explain different connection string formats in the docstring (#1132) * feat: specify individual connection parameters * style: fix excessive line length * test: correct env variables * docs: update PgvectorDocumentStore docstring Co-authored-by: David S. Batista * style: reorder parameters * style: reorder parameters * revert: remove individual connection params * docs: update postgres connection string docs --------- Co-authored-by: David S. 
Batista --- .../document_stores/pgvector/document_store.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index a02c46200..1b1333f5c 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -96,7 +96,11 @@ def __init__( A specific table to store Haystack documents will be created if it doesn't exist yet. :param connection_string: The connection string to use to connect to the PostgreSQL database, defined as an - environment variable, e.g.: `PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"` + environment variable. It can be provided in either URI format + e.g.: `PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"`, or keyword/value format + e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"` + See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) + for more details. :param table_name: The name of the table to use to store Haystack documents. :param language: The language to be used to parse query and document content in keyword retrieval. To see the list of available languages, you can run the following SQL query in your PostgreSQL database: From a6bb28f9460b16716437ab9e619a2d2328bca6f1 Mon Sep 17 00:00:00 2001 From: emso-c <32599085+emso-c@users.noreply.github.com> Date: Tue, 15 Oct 2024 14:37:12 +0300 Subject: [PATCH 031/229] docs: update relative issue paths in CONTRIBUTING.md (#1128) * docs: update relative issue paths This commit fixes references to the GitHub's issues page. --- CONTRIBUTING.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 32f8ca677..e0ba3d036 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,14 +48,14 @@ By participating, you are expected to uphold this code. Please report unacceptab ## I Have a Question > [!TIP] -> If you want to ask a question, we assume that you have read the available [Documentation](https://docs.haystack.deepset.ai/v2.0/docs/intro). +> If you want to ask a question, we assume that you have read the available [documentation](https://docs.haystack.deepset.ai/docs/intro). -Before you ask a question, it is best to search for existing [Issues](/issues) that might help you. In case you have +Before you ask a question, it is best to search for existing [issues](/../../issues) that might help you. In case you have found a suitable issue and still need clarification, you can write your question in this issue. It is also advisable to search the internet for answers first. If you then still feel the need to ask a question and need clarification, you can use one of our -[Community Channels](https://haystack.deepset.ai/community), Discord in particular is often very helpful. +[community channels](https://haystack.deepset.ai/community). Discord in particular is often very helpful. ## Reporting Bugs @@ -67,8 +67,8 @@ investigate carefully, collect information and describe the issue in detail in y following steps in advance to help us fix any potential bug as fast as possible. - Make sure that you are using the latest version. -- Determine if your bug is really a bug and not an error on your side e.g. 
using incompatible environment components/versions (Make sure that you have read the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/intro). If you are looking for support, you might want to check [this section](#i-have-a-question)). -- To see if other users have experienced (and potentially already solved) the same issue you are having, check if there is not already a bug report existing for your bug or error in the [bug tracker](/issues). +- Determine if your bug is really a bug and not an error on your side e.g. using incompatible environment components/versions (Make sure that you have read the [documentation](https://docs.haystack.deepset.ai/docs/intro). If you are looking for support, you might want to check [this section](#i-have-a-question)). +- To see if other users have experienced (and potentially already solved) the same issue you are having, check if there is not already a bug report existing for your bug or error in the [bug tracker](/../../issues?labels=bug). - Also make sure to search the internet (including Stack Overflow) to see if users outside of the GitHub community have discussed the issue. - Collect information about the bug: - OS, Platform and Version (Windows, Linux, macOS, x86, ARM) @@ -85,7 +85,7 @@ following steps in advance to help us fix any potential bug as fast as possible. We use GitHub issues to track bugs and errors. If you run into an issue with the project: -- Open an [Issue of type Bug Report](/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=). +- Open an [issue of type Bug Report](/../../issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=). - Explain the behavior you would expect and the actual behavior. - Please provide as much context as possible and describe the *reproduction steps* that someone else can follow to recreate the issue on their own. This usually includes your code. For good bug reports you should isolate the problem and create a reduced test case. - Provide the information you collected in the previous section. @@ -94,7 +94,7 @@ Once it's filed: - The project team will label the issue accordingly. - A team member will try to reproduce the issue with your provided steps. If there are no reproduction steps or no obvious way to reproduce the issue, the team will ask you for those steps. -- If the team is able to reproduce the issue, the issue will scheduled for a fix, or left to be [implemented by someone](#your-first-code-contribution). +- If the team can reproduce the issue, it will either be scheduled for a fix or made available for [community contribution](#contribute-code). ## Suggesting Enhancements @@ -106,14 +106,14 @@ to existing ones. Following these guidelines will help maintainers and the commu ### Before Submitting an Enhancement - Make sure that you are using the latest version. -- Read the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/intro) carefully and find out if the functionality is already covered, maybe by an individual configuration. -- Perform a [search](/issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. +- Read the [documentation](https://docs.haystack.deepset.ai/docs/intro) carefully and find out if the functionality is already covered, maybe by an individual configuration. +- Perform a [search](/../../issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. 
- Find out whether your idea fits with the scope and aims of the project. It's up to you to make a strong case to convince the project's developers of the merits of this feature. Keep in mind that we want features that will be useful to the majority of our users and not just a small subset. If you're just targeting a minority of users, consider writing and distributing the integration on your own. ### How Do I Submit a Good Enhancement Suggestion? -Enhancement suggestions are tracked as GitHub issues of type [Feature request for existing integrations](/issues/new?assignees=&labels=feature+request&projects=&template=feature-request-for-existing-integrations.md&title=). +Enhancement suggestions are tracked as GitHub issues of type [Feature request for existing integrations](/../../issues/new?assignees=&labels=feature+request&projects=&template=feature-request-for-existing-integrations.md&title=). - Use a **clear and descriptive title** for the issue to identify the suggestion. - Fill the issue following the template @@ -129,8 +129,8 @@ Enhancement suggestions are tracked as GitHub issues of type [Feature request fo If this is your first contribution, a good starting point is looking for an open issue that's marked with the label ["good first issue"](https://github.com/deepset-ai/haystack-core-integrations/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). The core contributors periodically mark certain issues as good for first-time contributors. Those issues are usually -limited in scope, easy fixable and low priority, so there is absolutely no reason why you should not try fixing them, -it's a good excuse to start looking into the project and a safe space for experimenting failure: if you don't get the +limited in scope, easy fixable and low priority, so there is absolutely no reason why you should not try fixing them. +It's also a good excuse to start looking into the project and a safe space for experimenting failure: if you don't get the grasp of something, pick another one! ### Setting up your development environment @@ -279,7 +279,7 @@ The Python API docs detail the source code: classes, functions, and parameters t This type of documentation is extracted from the source code itself, and contributors should pay attention when they change the code to also change relevant comments and docstrings. This type of documentation is mostly useful to developers, but it can be handy for users at times. You can browse it on the dedicated section in the -[documentation website](https://docs.haystack.deepset.ai/v2.0/reference/integrations-chroma). +[documentation website](https://docs.haystack.deepset.ai/reference/integrations-chroma). 
We use `pydoc-markdown` to convert docstrings into properly formatted Markdown files, and while the CI takes care of generating and publishing the updated documentation at every merge on the `main` branch, you can generate the docs From a722a22761f3671bfc299467f28769b1eb789f74 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 15 Oct 2024 14:49:16 +0200 Subject: [PATCH 032/229] fix: make sure that streaming works with function calls - (drop python3.8) (#1137) * fix streaming w function calls - drop python 3.8 * keep metadata var * fmt --- integrations/google_ai/pyproject.toml | 3 +-- .../generators/google_ai/chat/gemini.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index d06e0a53f..88fbcd61c 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -7,7 +7,7 @@ name = "google-ai-haystack" dynamic = ["version"] description = 'Use models like Gemini via Makersuite' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] @@ -15,7 +15,6 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 56c84968b..8efa8cda7 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -347,20 +347,21 @@ def _get_stream_response( replies: List[ChatMessage] = [] for chunk in stream: content: Union[str, Dict[str, Any]] = "" - metadata = chunk.to_dict() # we store whole chunk as metadata in streaming calls - for candidate in chunk.candidates: - for part in candidate.content.parts: - if part.text != "": - content = part.text + dict_chunk = chunk.to_dict() + metadata = dict(dict_chunk) # we copy and store the whole chunk as metadata in streaming calls + for candidate in dict_chunk["candidates"]: + for part in candidate["content"]["parts"]: + if "text" in part and part["text"] != "": + content = part["text"] replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None)) - elif part.function_call is not None: - metadata["function_call"] = part.function_call - content = dict(part.function_call.args.items()) + elif "function_call" in part and len(part["function_call"]) > 0: + metadata["function_call"] = part["function_call"] + content = part["function_call"]["args"] replies.append( ChatMessage( content=content, role=ChatRole.ASSISTANT, - name=part.function_call.name, + name=part["function_call"]["name"], meta=metadata, ) ) From a04dae99c9d6e38e63a662a2f3ba2e9edc8d0f1b Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 15 Oct 2024 12:51:27 +0000 Subject: [PATCH 033/229] Update the changelog --- integrations/google_ai/CHANGELOG.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 
0fc7ce0ab..3f3ecaf79 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,10 +1,16 @@ # Changelog -## [unreleased] +## [integrations/google_ai-v2.0.1] - 2024-10-15 + +### 🚀 Features + +- Add chatrole tests and meta for GeminiChatGenerators (#1090) ### 🐛 Bug Fixes - Remove the use of deprecated gemini models (#1032) +- Chat roles for model responses in chat generators (#1030) +- Make sure that streaming works with function calls - (drop python3.8) (#1137) ### 🧪 Testing @@ -14,6 +20,7 @@ - Retry tests to reduce flakyness (#836) - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) ### Docs From 9a3c2e060edee9298536dac68e875f10883469fd Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Wed, 16 Oct 2024 09:44:32 +0200 Subject: [PATCH 034/229] fix: avoid bedrock read timeout (add boto3_config param) (#1135) * fix: avoid bedrock read timeout * fix lint * fix test * add from_dict test --- .../components/generators/amazon_bedrock/generator.py | 10 +++++++++- integrations/amazon_bedrock/tests/test_generator.py | 5 +++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 1edde3526..193332009 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -3,6 +3,7 @@ import re from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from botocore.config import Config from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk @@ -87,6 +88,7 @@ def __init__( max_length: Optional[int] = 100, truncate: Optional[bool] = True, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + boto3_config: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -102,6 +104,7 @@ def __init__( :param truncate: Whether to truncate the prompt or not. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param boto3_config: The configuration for the boto3 client. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -120,6 +123,7 @@ def __init__( self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name self.streaming_callback = streaming_callback + self.boto3_config = boto3_config self.kwargs = kwargs def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -133,7 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), ) - self.client = session.client("bedrock-runtime") + config: Optional[Config] = None + if self.boto3_config: + config = Config(**self.boto3_config) + self.client = session.client("bedrock-runtime", config=config) except Exception as exception: msg = ( "Could not connect to Amazon Bedrock. 
Make sure the AWS environment is configured correctly. " @@ -273,6 +280,7 @@ def to_dict(self) -> Dict[str, Any]: max_length=self.max_length, truncate=self.truncate, streaming_callback=callback_name, + boto3_config=self.boto3_config, **self.kwargs, ) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 61ae9d6b4..2ccd5a3fa 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -36,6 +36,7 @@ def test_to_dict(mock_boto3_session): "truncate": False, "temperature": 10, "streaming_callback": None, + "boto3_config": None, }, } @@ -57,12 +58,16 @@ def test_from_dict(mock_boto3_session): "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, + "boto3_config": { + "read_timeout": 1000, + }, }, } ) assert generator.max_length == 99 assert generator.model == "anthropic.claude-v2" + assert generator.boto3_config == {"read_timeout": 1000} def test_default_constructor(mock_boto3_session, set_env_variables): From f95dd06614ed3089fa1fe4a0ce780b52949dc8e3 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 16 Oct 2024 07:47:22 +0000 Subject: [PATCH 035/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index e7b0db667..43ab788d3 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/amazon_bedrock-v1.0.4] - 2024-10-16 + +### 🐛 Bug Fixes + +- Avoid bedrock read timeout (add boto3_config param) (#1135) + ## [integrations/amazon_bedrock-v1.0.3] - 2024-10-04 ### 🐛 Bug Fixes From ac0e4c2f8c8d0dce7a32e8e3a3fe74362b0686dd Mon Sep 17 00:00:00 2001 From: Abraham Yusuf Date: Thu, 17 Oct 2024 10:55:01 +0200 Subject: [PATCH 036/229] feat: add prefixes to supported model patterns to allow cross region model ids (#1127) * feat: add prefixes to supported model patterns to allow cross region model ids --- .../amazon_bedrock/chat/chat_generator.py | 6 +++--- .../generators/amazon_bedrock/generator.py | 14 +++++++------- .../amazon_bedrock/tests/test_chat_generator.py | 6 ++++-- .../amazon_bedrock/tests/test_generator.py | 7 ++++++- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index e1732646a..6bb3cc301 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -58,9 +58,9 @@ class AmazonBedrockChatGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"(.+\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, - r"meta.llama2.*": MetaLlama2ChatAdapter, - r"mistral.*": MistralChatAdapter, + r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, + r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter, + r"([a-z]{2}\.)?mistral.*": MistralChatAdapter, } def __init__( diff --git 
a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 193332009..c6c814de4 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -66,13 +66,13 @@ class AmazonBedrockGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = { - r"amazon.titan-text.*": AmazonTitanAdapter, - r"ai21.j2.*": AI21LabsJurassic2Adapter, - r"cohere.command-[^r].*": CohereCommandAdapter, - r"cohere.command-r.*": CohereCommandRAdapter, - r"(.+\.)?anthropic.claude.*": AnthropicClaudeAdapter, - r"meta.llama.*": MetaLlamaAdapter, - r"mistral.*": MistralAdapter, + r"([a-z]{2}\.)?amazon.titan-text.*": AmazonTitanAdapter, + r"([a-z]{2}\.)?ai21.j2.*": AI21LabsJurassic2Adapter, + r"([a-z]{2}\.)?cohere.command-[^r].*": CohereCommandAdapter, + r"([a-z]{2}\.)?cohere.command-r.*": CohereCommandRAdapter, + r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeAdapter, + r"([a-z]{2}\.)?meta.llama.*": MetaLlamaAdapter, + r"([a-z]{2}\.)?mistral.*": MistralAdapter, } def __init__( diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 49abc0979..571e03eb2 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -243,7 +243,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.run(messages=messages) # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called(), + mock_ensure_token_limit.assert_not_called() # Check the prompt passed to prepare_body generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False) @@ -261,6 +261,9 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("us.meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference + ("eu.meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference + ("de.meta.llama2-130b-v5", MetaLlama2ChatAdapter), # cross-region inference ("unknown_model", None), ], ) @@ -517,7 +520,6 @@ def test_get_responses(self) -> None: @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration def test_default_inference_params(self, model_name, chat_messages): - client = AmazonBedrockChatGenerator(model=model_name) response = client.run(chat_messages) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 2ccd5a3fa..79246b4aa 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -225,7 +225,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.run(prompt=long_prompt_text) # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called(), + mock_ensure_token_limit.assert_not_called() # Check the prompt passed to prepare_body generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text, stream=False) @@ -251,10 
+251,13 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("ai21.j2-mega-v5", AI21LabsJurassic2Adapter), # artificial ("amazon.titan-text-lite-v1", AmazonTitanAdapter), ("amazon.titan-text-express-v1", AmazonTitanAdapter), + ("us.amazon.titan-text-express-v1", AmazonTitanAdapter), # cross-region inference ("amazon.titan-text-agile-v1", AmazonTitanAdapter), ("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial ("meta.llama2-13b-chat-v1", MetaLlamaAdapter), ("meta.llama2-70b-chat-v1", MetaLlamaAdapter), + ("eu.meta.llama2-13b-chat-v1", MetaLlamaAdapter), # cross-region inference + ("us.meta.llama2-70b-chat-v1", MetaLlamaAdapter), # cross-region inference ("meta.llama2-130b-v5", MetaLlamaAdapter), # artificial ("meta.llama3-8b-instruct-v1:0", MetaLlamaAdapter), ("meta.llama3-70b-instruct-v1:0", MetaLlamaAdapter), @@ -262,6 +265,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("mistral.mistral-7b-instruct-v0:2", MistralAdapter), ("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), ("mistral.mistral-large-2402-v1:0", MistralAdapter), + ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference + ("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial ("unknown_model", None), ], From 1e1b17830613cf357dd17dd52d4cf25e30cb8a58 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 17 Oct 2024 14:07:47 +0000 Subject: [PATCH 037/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 43ab788d3..de3647acc 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/amazon_bedrock-v1.0.5] - 2024-10-17 + +### 🚀 Features + +- Add prefixes to supported model patterns to allow cross region model ids (#1127) + ## [integrations/amazon_bedrock-v1.0.4] - 2024-10-16 ### 🐛 Bug Fixes From 2f12690ba6fae91168992ddaffc0a228ee49bc79 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 17 Oct 2024 16:21:26 +0200 Subject: [PATCH 038/229] fix: make "project-id" parameter optional during initialization (#1141) * Make project-id param optional --- .../generators/google_vertex/chat/gemini.py | 6 +++--- .../generators/google_vertex/gemini.py | 6 +++--- .../google_vertex/tests/chat/test_gemini.py | 19 +++++++++-------- .../google_vertex/tests/test_gemini.py | 21 +++++++++++-------- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index f09692daf..c52f76dc6 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -36,7 +36,7 @@ class VertexAIGeminiChatGenerator: from haystack.dataclasses import ChatMessage from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator - gemini_chat = VertexAIGeminiChatGenerator(project_id=project_id) + gemini_chat = VertexAIGeminiChatGenerator() messages = [ChatMessage.from_user("Tell me the name of a movie")] res = 
gemini_chat.run(messages) @@ -50,7 +50,7 @@ def __init__( self, *, model: str = "gemini-1.5-flash", - project_id: str, + project_id: Optional[str] = None, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, @@ -65,7 +65,7 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 2b1c1b477..737f2e668 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -32,7 +32,7 @@ class VertexAIGeminiGenerator: from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator - gemini = VertexAIGeminiGenerator(project_id=project_id) + gemini = VertexAIGeminiGenerator() result = gemini.run(parts = ["What is the most interesting thing you know?"]) for answer in result["replies"]: print(answer) @@ -54,7 +54,7 @@ def __init__( self, *, model: str = "gemini-1.5-flash", - project_id: str, + project_id: Optional[str] = None, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, @@ -69,7 +69,7 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. :param generation_config: The generation config to use. 
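A minimal sketch of the now-optional `project_id`, mirroring the docstring example above (not part of the patch; it assumes Google Cloud Application Default Credentials are already configured, e.g. via `gcloud auth application-default login`, and the project and location values below are placeholders):

```python
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator

# project_id and location may now be omitted; they are resolved from the
# Application Default Credentials used to authenticate against Vertex AI.
gemini = VertexAIGeminiGenerator(model="gemini-1.5-flash")
result = gemini.run(["What is the most interesting thing you know?"])
for answer in result["replies"]:
    print(answer)

# Passing them explicitly is still supported (placeholder values):
gemini_explicit = VertexAIGeminiGenerator(
    model="gemini-1.5-flash",
    project_id="my-gcp-project",
    location="us-central1",
)
```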
diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 6b1308dab..0d77bd9c6 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -90,14 +90,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiChatGenerator( - project_id="TestID123", - ) + gemini = VertexAIGeminiChatGenerator() assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { "model": "gemini-1.5-flash", - "project_id": "TestID123", + "project_id": None, "location": None, "generation_config": None, "safety_settings": None, @@ -132,6 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiChatGenerator( project_id="TestID123", + location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], @@ -144,7 +143,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "init_parameters": { "model": "gemini-1.5-flash", "project_id": "TestID123", - "location": None, + "location": "TestLocation", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -194,7 +193,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { - "project_id": "TestID123", + "project_id": None, "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, @@ -205,7 +204,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): ) assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" + assert gemini._project_id is None assert gemini._safety_settings is None assert gemini._tools is None assert gemini._tool_config is None @@ -221,6 +220,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { "project_id": "TestID123", + "location": "TestLocation", "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, @@ -272,6 +272,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._tool_config, ToolConfig) @@ -296,7 +297,7 @@ def test_run(mock_generative_model): ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), ] - gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None) + gemini = VertexAIGeminiChatGenerator() response = gemini.run(messages=messages) mock_model.send_message.assert_called_once() @@ -321,7 +322,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True - gemini = 
VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback) + gemini = VertexAIGeminiChatGenerator(streaming_callback=streaming_callback) messages = [ ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 9ec3529d7..b3d6dd5f5 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -78,14 +78,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiGenerator( - project_id="TestID123", - ) + gemini = VertexAIGeminiGenerator() assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { "model": "gemini-1.5-flash", - "project_id": "TestID123", + "project_id": None, "location": None, "generation_config": None, "safety_settings": None, @@ -120,6 +118,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiGenerator( project_id="TestID123", + location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], @@ -131,7 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "init_parameters": { "model": "gemini-1.5-flash", "project_id": "TestID123", - "location": None, + "location": "TestLocation", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -181,7 +180,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { - "project_id": "TestID123", + "project_id": None, + "location": None, "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, @@ -194,7 +194,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): ) assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" + assert gemini._project_id is None + assert gemini._location is None assert gemini._safety_settings is None assert gemini._tools is None assert gemini._tool_config is None @@ -210,6 +211,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { "project_id": "TestID123", + "location": "TestLocation", "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, @@ -261,6 +263,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) @@ -277,7 +280,7 @@ def test_run(mock_generative_model): mock_model.generate_content.return_value = MagicMock() mock_generative_model.return_value = mock_model - gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None) + gemini = 
VertexAIGeminiGenerator() response = gemini.run(["What's the weather like today?"]) @@ -303,7 +306,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True - gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) + gemini = VertexAIGeminiGenerator(model="gemini-pro", streaming_callback=streaming_callback) gemini.run(["Come on, stream!"]) assert streaming_callback_called From f166d91e068136fa8ae21b1755d34ae06fa0107e Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 18 Oct 2024 12:49:13 +0200 Subject: [PATCH 039/229] ci: adopt uv as installer (#1142) * try uv * pip for typing * fix docs generation * fix weaviate * revert weaviate change * no setuptools * try to see if nightly with Haystack main works * trigger * fix instructor embedders * Fix formatting in integrations/chroma/pyproject.toml --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> --- integrations/amazon_bedrock/pyproject.toml | 4 +++- integrations/amazon_sagemaker/pyproject.toml | 4 +++- integrations/anthropic/pyproject.toml | 4 +++- integrations/astra/pyproject.toml | 4 +++- integrations/chroma/pyproject.toml | 3 +++ integrations/cohere/pyproject.toml | 4 +++- integrations/deepeval/pyproject.toml | 4 +++- integrations/elasticsearch/pyproject.toml | 4 +++- integrations/fastembed/pyproject.toml | 4 +++- integrations/google_ai/pyproject.toml | 4 +++- integrations/google_vertex/pyproject.toml | 4 +++- integrations/instructor_embedders/pyproject.toml | 5 ++++- integrations/jina/pyproject.toml | 4 +++- integrations/langfuse/pyproject.toml | 4 +++- integrations/llama_cpp/pyproject.toml | 4 +++- integrations/mistral/pyproject.toml | 4 +++- integrations/mongodb_atlas/pyproject.toml | 4 +++- integrations/nvidia/pyproject.toml | 4 +++- integrations/ollama/pyproject.toml | 4 +++- integrations/opensearch/pyproject.toml | 4 +++- integrations/optimum/pyproject.toml | 4 +++- integrations/pgvector/pyproject.toml | 4 +++- integrations/pinecone/pyproject.toml | 4 +++- integrations/qdrant/pyproject.toml | 4 +++- integrations/ragas/pyproject.toml | 4 +++- integrations/snowflake/pyproject.toml | 4 +++- integrations/unstructured/pyproject.toml | 4 +++- integrations/weaviate/pyproject.toml | 4 +++- 28 files changed, 85 insertions(+), 27 deletions(-) diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 1298abfab..872d4933b 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/amazon_bedrock-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml index a25b806f6..219b4c2df 100644 --- a/integrations/amazon_sagemaker/pyproject.toml +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -45,6 +45,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/amazon_sagemaker-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -65,8 +66,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/anthropic/pyproject.toml b/integrations/anthropic/pyproject.toml index 987f017be..21e23fbb4 100644 --- a/integrations/anthropic/pyproject.toml +++ b/integrations/anthropic/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/anthropic-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index 25bcf20b8..f9e8fe982 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/astra-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -59,8 +60,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 27b204432..7f0943a30 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/chroma-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -61,8 +62,10 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.9", "3.10"] [tool.hatch.envs.lint] +installer = "uv" detached = true dependencies = [ + "pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index d86165668..262b1612d 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ diff --git a/integrations/deepeval/pyproject.toml b/integrations/deepeval/pyproject.toml index 5d81fa0a5..6ef64387b 100644 --- a/integrations/deepeval/pyproject.toml +++ b/integrations/deepeval/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/deepeval-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -55,8 +56,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index 47b168f30..8bf01cc65 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/elasticsearch-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -61,8 +62,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/fastembed/pyproject.toml b/integrations/fastembed/pyproject.toml index 69aba5562..8686c9e7a 100644 --- a/integrations/fastembed/pyproject.toml +++ b/integrations/fastembed/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/fastembed-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -62,8 +63,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index 88fbcd61c..9a4a070e7 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -40,6 +40,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/google_ai-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -58,8 +59,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index 51bc4ffd7..d8b7b3408 100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/google_vertex-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -59,8 +60,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/instructor_embedders/pyproject.toml b/integrations/instructor_embedders/pyproject.toml index 458c0ae0c..e165aa10d 100644 --- a/integrations/instructor_embedders/pyproject.toml +++ b/integrations/instructor_embedders/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ # Commenting some of them to not interfere with the dependencies of Haystack. #"transformers==4.20.0", "datasets>=2.2.0", + "huggingface_hub<0.26.0", #"pyarrow==8.0.0", "jsonlines", "numpy", @@ -64,6 +65,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/instructor_embedders-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -83,8 +85,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["38", "39", "310", "311"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index cbd8df479..c89eeacb4 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -43,6 +43,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/jina-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -58,8 +59,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/langfuse/pyproject.toml b/integrations/langfuse/pyproject.toml index 61de4596c..44397b572 100644 --- a/integrations/langfuse/pyproject.toml +++ b/integrations/langfuse/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/langfuse-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -64,8 +65,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index acf42d958..a33434e1b 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -45,6 +45,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/llama_cpp-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -65,8 +66,9 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/mistral/pyproject.toml b/integrations/mistral/pyproject.toml index 16f332331..06d02c0aa 100644 --- a/integrations/mistral/pyproject.toml +++ b/integrations/mistral/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/mistral-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ diff --git a/integrations/mongodb_atlas/pyproject.toml b/integrations/mongodb_atlas/pyproject.toml index 95ed6c03a..bdf1a2dc1 100644 --- a/integrations/mongodb_atlas/pyproject.toml +++ b/integrations/mongodb_atlas/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/mongodb_atlas-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -62,8 +63,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/nvidia/pyproject.toml b/integrations/nvidia/pyproject.toml index b5c6dd205..7f0048c1b 100644 --- a/integrations/nvidia/pyproject.toml +++ b/integrations/nvidia/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/nvidia-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -62,8 +63,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index bc8555140..598d1d214 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -46,6 +46,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/ollama-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -66,8 +67,9 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 24f1653bd..54c194470 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/opensearch-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -63,8 +64,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "boto3"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "boto3"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml index 305af6042..6149997ed 100644 --- a/integrations/optimum/pyproject.toml +++ b/integrations/optimum/pyproject.toml @@ -54,6 +54,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/optimum-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -76,8 +77,9 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index 014d163bc..3f20dfbb1 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/pgvector-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -62,8 +63,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/pinecone/pyproject.toml b/integrations/pinecone/pyproject.toml index 3f2e4d6bd..1a19cb2b7 100644 --- a/integrations/pinecone/pyproject.toml +++ b/integrations/pinecone/pyproject.toml @@ -44,6 +44,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/pinecone-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -66,8 +67,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index 898fd2dcf..f0e7e7342 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -44,6 +44,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/qdrant-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -58,8 +59,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/ragas/pyproject.toml b/integrations/ragas/pyproject.toml index dd56e35f6..179bcce16 100644 --- a/integrations/ragas/pyproject.toml +++ b/integrations/ragas/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/ragas-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -61,8 +62,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml index 355e9d090..0e089fe79 100644 --- a/integrations/snowflake/pyproject.toml +++ b/integrations/snowflake/pyproject.toml @@ -43,6 +43,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/snowflake-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -58,8 +59,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/unstructured/pyproject.toml b/integrations/unstructured/pyproject.toml index 88bd463b2..14f58594c 100644 --- a/integrations/unstructured/pyproject.toml +++ b/integrations/unstructured/pyproject.toml @@ -40,6 +40,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/unstructured-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index 22d3a160d..0e6a0d18d 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -47,6 +47,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/weaviate-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "ipython"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -61,8 +62,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] From 863ace094a45588d7dbec3706039d3ad32d78632 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 18 Oct 2024 14:13:45 +0000 Subject: [PATCH 040/229] Update the changelog --- .../instructor_embedders/CHANGELOG.md | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 integrations/instructor_embedders/CHANGELOG.md diff --git a/integrations/instructor_embedders/CHANGELOG.md b/integrations/instructor_embedders/CHANGELOG.md new file mode 100644 index 000000000..2c22fa90c --- /dev/null +++ b/integrations/instructor_embedders/CHANGELOG.md @@ -0,0 +1,61 @@ +# Changelog + +## [integrations/instructor_embedders-v0.4.1] - 2024-10-18 + +### 📚 Documentation + +- Disable-class-def (#556) + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + +## [integrations/instructor_embedders-v0.4.0] - 2024-02-21 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) + +## [integrations/instructor_embedders-v0.3.0] - 2024-02-15 + +### 🚀 Features + +- Generate API docs (#380) + +### 📚 Documentation + +- Update paths and titles (#397) + +## [integrations/instructor_embedders-v0.2.1] - 2024-01-30 + +## [integrations/instructor_embedders-v0.2.0] - 2024-01-22 + +### ⚙️ Miscellaneous Tasks + +- Replace - with _ (#114) +- chore!: Rename `model_name_or_path` to `model` in the Instructor integration (#229) + +* rename model_name_or_path in doc embedder + +* fix tests for doc embedder + +* rename model_name_or_path to model in text embedder + +* fix tests for text embedder + +* feedback + + From a73355ed5a8bdebc7883d81748ec43e76c9c8007 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 18 Oct 2024 17:35:58 +0200 Subject: [PATCH 041/229] fix: compatibility with Weaviate 4.9.0 (#1143) * compatibiliti with weaviate 4.9.0 * extend from_dict test --- .github/workflows/weaviate.yml | 2 +- integrations/weaviate/pyproject.toml | 4 ++-- integrations/weaviate/tests/test_document_store.py | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/weaviate.yml b/.github/workflows/weaviate.yml index 06a4bc289..36c30f069 100644 --- a/.github/workflows/weaviate.yml +++ b/.github/workflows/weaviate.yml @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 diff --git a/integrations/weaviate/pyproject.toml 
b/integrations/weaviate/pyproject.toml index 0e6a0d18d..70b045bc4 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -7,7 +7,7 @@ name = "weaviate-haystack" dynamic = ["version"] description = "An integration of Weaviate vector database with Haystack" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] @@ -25,7 +25,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "weaviate-client>=4.0", + "weaviate-client>=4.9", "haystack-pydoc-tools", "python-dateutil", ] diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 190c23408..70f1e1eb2 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -265,6 +265,7 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): "session_pool_connections": 20, "session_pool_maxsize": 100, "session_pool_max_retries": 3, + "session_pool_timeout": 5, }, "proxies": {"http": "http://proxy:1234", "https": None, "grpc": None}, "timeout": [30, 90], @@ -302,6 +303,7 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): "connection": { "session_pool_connections": 20, "session_pool_maxsize": 20, + "session_pool_timeout": 5, }, "proxies": {"http": "http://proxy:1234"}, "timeout": [10, 60], @@ -338,6 +340,7 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): assert document_store._embedded_options.grpc_port == DEFAULT_GRPC_PORT assert document_store._additional_config.connection.session_pool_connections == 20 assert document_store._additional_config.connection.session_pool_maxsize == 20 + assert document_store._additional_config.connection.session_pool_timeout == 5 def test_to_data_object(self, document_store, test_files_path): doc = Document(content="test doc") From f5995b74bc00a687f81214315e1f8291f30b33b7 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 18 Oct 2024 15:37:32 +0000 Subject: [PATCH 042/229] Update the changelog --- integrations/weaviate/CHANGELOG.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/integrations/weaviate/CHANGELOG.md b/integrations/weaviate/CHANGELOG.md index dacf3fef8..ec15cbeef 100644 --- a/integrations/weaviate/CHANGELOG.md +++ b/integrations/weaviate/CHANGELOG.md @@ -1,10 +1,17 @@ # Changelog -## [integrations/weaviate-v3.0.0] - 2024-09-12 +## [integrations/weaviate-v4.0.0] - 2024-10-18 + +### 🐛 Bug Fixes + +- Compatibility with Weaviate 4.9.0 (#1143) ### ⚙️ Miscellaneous Tasks - Weaviate - remove legacy filter support (#1070) +- Update changelog after removing legacy filters (#1083) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) ## [integrations/weaviate-v2.2.1] - 2024-09-07 @@ -58,8 +65,6 @@ This PR will also push the docs to Readme - Fix weaviate auth tests (#488) - - ### 📚 Documentation - Update category slug (#442) From 61ac2f457c53697e8f906d9f44c39995a008f140 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Fri, 18 Oct 2024 17:44:09 +0200 Subject: [PATCH 043/229] fix: fixing Chroma tests due `chromadb` update behaviour change (#1148) * initial import * updating tests * tryiing to get the hatch linting to run locally * simplifying test_multiple_contains - so that it also uses the fixtures * removing unused imports --- .../document_stores/chroma/filters.py | 2 +- .../chroma/tests/test_document_store.py | 20 +------------------ 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py index ef5c920a7..60046b6ad 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py @@ -27,7 +27,7 @@ class ChromaFilter: """ Dataclass to store the converted filter structure used in Chroma queries. - Following filter criterias are supported: + Following filter criteria are supported: - `ids`: A list of document IDs to filter by in Chroma collection. - `where`: A dictionary of metadata filters applied to the documents. - `where_document`: A dictionary of content-based filters applied to the documents' content. diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index f386b44ba..987f6d8b7 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -283,6 +283,7 @@ def test_contains(self, document_store: ChromaDocumentStore, filterable_docs: Li ) def test_multiple_contains(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): + filterable_docs = [doc for doc in filterable_docs if doc.content] # remove documents without content document_store.write_documents(filterable_docs) filters = { "operator": "OR", @@ -342,25 +343,6 @@ def test_nested_logical_filters(self, document_store: ChromaDocumentStore, filte ], ) - # Override inequality tests from FilterDocumentsTest - # because chroma doesn't return documents with absent meta fields - - def test_comparison_not_equal(self, document_store, filterable_docs): - """Test filter_documents() with != comparator""" - document_store.write_documents(filterable_docs) - result = document_store.filter_documents({"field": "meta.number", "operator": "!=", "value": 100}) - self.assert_documents_are_equal( - result, [d for d in filterable_docs if "number" in d.meta and d.meta.get("number") != 100] - ) - - def test_comparison_not_in(self, document_store, filterable_docs): - """Test filter_documents() with 'not in' comparator""" - document_store.write_documents(filterable_docs) - result = document_store.filter_documents({"field": "meta.number", "operator": "not in", "value": [2, 9]}) - self.assert_documents_are_equal( - result, [d for d in filterable_docs if "number" in d.meta and d.meta.get("number") not in [2, 9]] - ) - @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") def test_comparison_equal_with_dataframe( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] From 067adba2673304bf7f02a642393f1338c7298524 Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Tue, 22 Oct 2024 06:33:36 -0700 Subject: [PATCH 044/229] feat: Update astradb integration for latest client library (#1145) * Update astradb integration for latest client library * Update CHANGELOG.md * Ruff check update * Black linting updates * Tweak 
to versioning for astrapy * removing CHANGELOG.MD changes since those are automatically added --------- Co-authored-by: David S. Batista --- integrations/astra/README.md | 24 ++- integrations/astra/examples/requirements.txt | 2 +- integrations/astra/pyproject.toml | 7 +- .../document_stores/astra/astra_client.py | 173 ++++++++---------- .../astra/tests/test_document_store.py | 26 +-- 5 files changed, 109 insertions(+), 123 deletions(-) diff --git a/integrations/astra/README.md b/integrations/astra/README.md index f679b7207..9ee47b8c9 100644 --- a/integrations/astra/README.md +++ b/integrations/astra/README.md @@ -6,17 +6,18 @@ ```bash pip install astra-haystack - ``` ### Local Development + install astra-haystack package locally to run integration tests: Open in gitpod: [![Open in Gitpod](https://gitpod.io/button/open-in-gitpod.svg)](https://gitpod.io/#https://github.com/Anant/astra-haystack/tree/main) -Switch Python version to 3.9 (Requires 3.8+ but not 3.12) -``` +Switch Python version to 3.9 (Requires 3.9+ but not 3.12) + +```bash pyenv install 3.9 pyenv local 3.9 ``` @@ -33,7 +34,8 @@ Install requirements `pip install -r requirements.txt` Export environment variables -``` + +```bash export ASTRA_DB_API_ENDPOINT="https://-.apps.astra.datastax.com" export ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." export COLLECTION_NAME="my_collection" @@ -49,22 +51,25 @@ or This package includes Astra Document Store and Astra Embedding Retriever classes that integrate with Haystack, allowing you to easily perform document retrieval or RAG with Astra, and include those functions in Haystack pipelines. -### In order to use the Document Store directly: +### Use the Document Store Directly Import the Document Store: -``` + +```python from haystack_integrations.document_stores.astra import AstraDocumentStore from haystack.document_stores.types.policy import DuplicatePolicy ``` Load in environment variables: -``` + +```python namespace = os.environ.get("ASTRA_DB_KEYSPACE") collection_name = os.environ.get("COLLECTION_NAME", "haystack_vector_search") ``` Create the Document Store object (API Endpoint and Token are read off the environment): -``` + +```python document_store = AstraDocumentStore( collection_name=collection_name, namespace=namespace, @@ -80,7 +85,7 @@ Then you can use the document store functions like count_document below: Create the Document Store object like above, then import and create the Pipeline: -``` +```python from haystack import Pipeline pipeline = Pipeline() ``` @@ -101,7 +106,6 @@ or, > Astra DB collection '...' is detected as having the following indexing policy: {...}. This does not match the requested indexing policy for this object: {...}. In particular, there may be stricter limitations on the amount of text each string in a document can store. Consider indexing anew on a fresh collection to be able to store longer texts. - The reason for the warning is that the requested collection already exists on the database, and it is configured to [index all of its fields for search](https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option), possibly implicitly, by default. When the Haystack object tries to create it, it attempts to enforce, instead, an indexing policy tailored to the prospected usage: this is both to enable storing very long texts and to avoid indexing fields that will never be used in filtering a search (indexing those would also have a slight performance cost for writes). 
Typically there are two reasons why you may encounter the warning: diff --git a/integrations/astra/examples/requirements.txt b/integrations/astra/examples/requirements.txt index 710749bbe..221138666 100644 --- a/integrations/astra/examples/requirements.txt +++ b/integrations/astra/examples/requirements.txt @@ -1,4 +1,4 @@ haystack-ai sentence_transformers==2.2.2 openai==1.6.1 -astrapy>=0.7.7 \ No newline at end of file +astrapy>=1.5.0,<2.0 diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index f9e8fe982..5645cd5d3 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -7,7 +7,7 @@ name = "astra-haystack" dynamic = ["version"] description = '' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] authors = [{ name = "Anant Corporation", email = "support@anant.us" }] @@ -15,14 +15,13 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "pydantic", "typing_extensions", "astrapy"] +dependencies = ["haystack-ai", "pydantic", "typing_extensions", "astrapy>=1.5.0,<2.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra#readme" @@ -57,7 +56,7 @@ cov = ["test-cov", "cov-report"] cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11"] +python = ["3.9", "3.10", "3.11"] [tool.hatch.envs.lint] installer = "uv" diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index b594f87d3..6f2289786 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Union from warnings import warn -from astrapy.api import APIRequestError -from astrapy.db import AstraDB +from astrapy import DataAPIClient as AstraDBClient +from astrapy.constants import ReturnDocument +from astrapy.exceptions import CollectionAlreadyExistsException from haystack.version import __version__ as integration_version from pydantic.dataclasses import dataclass @@ -65,83 +66,78 @@ def __init__( self.similarity_function = similarity_function self.namespace = namespace - # Build the Astra DB object - self._astra_db = AstraDB( + # Get the keyspace from the collection name + my_client = AstraDBClient( + callers=[(CALLER_NAME, integration_version)], + ) + + # Get the database object + self._astra_db = my_client.get_database( api_endpoint=api_endpoint, token=token, - namespace=namespace, - caller_name=CALLER_NAME, - caller_version=integration_version, + keyspace=namespace, ) - indexing_options = {"indexing": {"deny": NON_INDEXED_FIELDS}} + indexing_options = {"deny": NON_INDEXED_FIELDS} try: # Create and connect to the newly created collection self._astra_db_collection = self._astra_db.create_collection( - collection_name=collection_name, + 
name=collection_name, dimension=embedding_dimension, - options=indexing_options, + indexing=indexing_options, ) - except APIRequestError: + except CollectionAlreadyExistsException as _: # possibly the collection is preexisting and has legacy # indexing settings: verify - get_coll_response = self._astra_db.get_collections(options={"explain": True}) - - collections = (get_coll_response["status"] or {}).get("collections") or [] - - preexisting = [collection for collection in collections if collection["name"] == collection_name] + preexisting = [ + coll_descriptor + for coll_descriptor in self._astra_db.list_collections() + if coll_descriptor.name == collection_name + ] if preexisting: - pre_collection = preexisting[0] # if it has no "indexing", it is a legacy collection; - # otherwise it's unexpected warn and proceed at user's risk - pre_col_options = pre_collection.get("options") or {} - if "indexing" not in pre_col_options: + # otherwise it's unexpected: warn and proceed at user's risk + pre_col_idx_opts = preexisting[0].options.indexing or {} + if not pre_col_idx_opts: warn( ( - f"Astra DB collection '{collection_name}' is " - "detected as having indexing turned on for all " - "fields (either created manually or by older " - "versions of this plugin). This implies stricter " - "limitations on the amount of text each string in a " - "document can store. Consider indexing anew on a " - "fresh collection to be able to store longer texts. " - "See https://github.com/deepset-ai/haystack-core-" - "integrations/blob/main/integrations/astra/README" - ".md#warnings-about-indexing for more details." + f"Collection '{collection_name}' is detected as " + "having indexing turned on for all fields " + "(either created manually or by older versions " + "of this plugin). This implies stricter " + "limitations on the amount of text" + " each entry can store. Consider indexing anew on a" + " fresh collection to be able to store longer texts." ), UserWarning, stacklevel=2, ) - self._astra_db_collection = self._astra_db.collection( - collection_name=collection_name, + self._astra_db_collection = self._astra_db.get_collection( + collection_name, + ) + # check if the indexing options match entirely + elif pre_col_idx_opts == indexing_options: + self._astra_db_collection = self._astra_db.get_collection( + collection_name, ) - elif pre_col_options["indexing"] != indexing_options["indexing"]: - detected_options_json = json.dumps(pre_col_options["indexing"]) - indexing_options_json = json.dumps(indexing_options["indexing"]) + else: + options_json = json.dumps(pre_col_idx_opts) warn( ( - f"Astra DB collection '{collection_name}' is " - "detected as having the following indexing policy: " - f"{detected_options_json}. This does not match the requested " - f"indexing policy for this object: {indexing_options_json}. " - "In particular, there may be stricter " - "limitations on the amount of text each string in a " - "document can store. Consider indexing anew on a " - "fresh collection to be able to store longer texts. " - "See https://github.com/deepset-ai/haystack-core-" - "integrations/blob/main/integrations/astra/README" - ".md#warnings-about-indexing for more details." + f"Collection '{collection_name}' has unexpected 'indexing'" + f" settings (options.indexing = {options_json})." + " This can result in odd behaviour when running " + " metadata filtering and/or unwarranted limitations" + " on storing long texts. Consider indexing anew on a" + " fresh collection." 
), UserWarning, stacklevel=2, ) - self._astra_db_collection = self._astra_db.collection( - collection_name=collection_name, + self._collection = self._astra_db.get_collection( + collection_name, ) - else: - # the collection mismatch lies elsewhere than the indexing - raise else: # other exception raise @@ -180,7 +176,7 @@ def query( return formatted_response def _query_without_vector(self, top_k, filters=None): - query = {"filter": filters, "options": {"limit": top_k}} + query = {"filter": filters, "limit": top_k} return self.find_documents(query) @@ -196,8 +192,11 @@ def _format_query_response(responses, include_metadata, include_values): score = response.pop("$similarity", None) text = response.pop("content", None) values = response.pop("$vector", None) if include_values else [] + metadata = response if include_metadata else {} # Add all remaining fields to the metadata + rsp = Response(_id, text, values, metadata, score) + final_res.append(rsp) return QueryResponse(final_res) @@ -219,17 +218,21 @@ def find_documents(self, find_query): :param find_query: a dictionary with the query options :returns: the documents found in the index """ - response_dict = self._astra_db_collection.find( + find_cursor = self._astra_db_collection.find( filter=find_query.get("filter"), sort=find_query.get("sort"), - options=find_query.get("options"), + limit=find_query.get("limit"), projection={"*": 1}, ) - if "data" in response_dict and "documents" in response_dict["data"]: - return response_dict["data"]["documents"] - else: - logger.warning(f"No documents found: {response_dict}") + find_results = [] + for result in find_cursor: + find_results.append(result) + + if not find_results: + logger.warning("No documents found.") + + return find_results def find_one_document(self, find_query): """ @@ -238,16 +241,15 @@ def find_one_document(self, find_query): :param find_query: a dictionary with the query options :returns: the document found in the index """ - response_dict = self._astra_db_collection.find_one( + find_result = self._astra_db_collection.find_one( filter=find_query.get("filter"), - options=find_query.get("options"), projection={"*": 1}, ) - if "data" in response_dict and "document" in response_dict["data"]: - return response_dict["data"]["document"] - else: - logger.warning(f"No document found: {response_dict}") + if not find_result: + logger.warning("No document found.") + + return find_result def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse: """ @@ -281,15 +283,8 @@ def insert(self, documents: List[Dict]): :param documents: a list of documents to insert :returns: the IDs of the inserted documents """ - response_dict = self._astra_db_collection.insert_many(documents=documents) - - inserted_ids = ( - response_dict["status"]["insertedIds"] - if "status" in response_dict and "insertedIds" in response_dict["status"] - else [] - ) - if "errors" in response_dict: - logger.error(response_dict["errors"]) + insert_result = self._astra_db_collection.insert_many(documents=documents) + inserted_ids = [str(_id) for _id in insert_result.inserted_ids] return inserted_ids @@ -303,23 +298,21 @@ def update_document(self, document: Dict, id_key: str): """ document_id = document.pop(id_key) - response_dict = self._astra_db_collection.find_one_and_update( + update_result = self._astra_db_collection.find_one_and_update( filter={id_key: document_id}, update={"$set": document}, - options={"returnDocument": "after"}, + return_document=ReturnDocument.AFTER, projection={"*": 1}, ) document[id_key] 
= document_id - if "status" in response_dict and "errors" not in response_dict: - if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]: - if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1: - return True + if update_result is None: + logger.warning(f"Documents {document_id} not updated in Astra DB.") - logger.warning(f"Documents {document_id} not updated in Astra DB.") + return False - return False + return True def delete( self, @@ -345,23 +338,13 @@ def delete( if "filter" in query["deleteMany"]: filter_dict = query["deleteMany"]["filter"] - deletion_counter = 0 - moredata = True - while moredata: - response_dict = self._astra_db_collection.delete_many(filter=filter_dict) - - if "moreData" not in response_dict.get("status", {}): - moredata = False + delete_result = self._astra_db_collection.delete_many(filter=filter_dict) - deletion_counter += int(response_dict["status"].get("deletedCount", 0)) + return delete_result.deleted_count - return deletion_counter - - def count_documents(self) -> int: + def count_documents(self, upper_bound: int = 10000) -> int: """ Count the number of documents in the Astra index. :returns: the number of documents in the index """ - documents_count = self._astra_db_collection.count_documents() - - return documents_count["status"]["count"] + return self._astra_db_collection.count_documents({}, upper_bound=upper_bound) diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index c4d1b6347..ef00b6b25 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -20,25 +20,14 @@ def mock_auth(monkeypatch): monkeypatch.setenv("ASTRA_DB_APPLICATION_TOKEN", "test_token") -@mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") +@mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDBClient") def test_init_is_lazy(_mock_client, mock_auth): # noqa _ = AstraDocumentStore() _mock_client.assert_not_called() -def test_namespace_init(mock_auth): # noqa - with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client: - _ = AstraDocumentStore().index - assert "namespace" in client.call_args.kwargs - assert client.call_args.kwargs["namespace"] is None - - _ = AstraDocumentStore(namespace="foo").index - assert "namespace" in client.call_args.kwargs - assert client.call_args.kwargs["namespace"] == "foo" - - def test_to_dict(mock_auth): # noqa - with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"): + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDBClient"): ds = AstraDocumentStore() result = ds.to_dict() assert result["type"] == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore" @@ -206,6 +195,17 @@ def test_filter_documents_by_id(self, document_store): result = document_store.filter_documents(filters={"field": "id", "operator": "==", "value": "1"}) self.assert_documents_are_equal(result, [docs[0]]) + def test_filter_documents_by_in_operator(self, document_store): + docs = [Document(id="3", content="test doc 3"), Document(id="4", content="test doc 4")] + document_store.write_documents(docs) + result = document_store.filter_documents(filters={"field": "id", "operator": "in", "value": ["3", "4"]}) + + # Sort the result in place by the id field + result.sort(key=lambda x: x.id) + + 
self.assert_documents_are_equal([result[0]], [docs[0]]) + self.assert_documents_are_equal([result[1]], [docs[1]]) + @pytest.mark.skip(reason="Unsupported filter operator not.") def test_not_operator(self, document_store, filterable_docs): pass From c09812c934147db5c9a8711758dcc2bada867de8 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 22 Oct 2024 13:50:54 +0000 Subject: [PATCH 045/229] Update the changelog --- integrations/astra/CHANGELOG.md | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/integrations/astra/CHANGELOG.md b/integrations/astra/CHANGELOG.md index 79bb9e35d..fff6cb65f 100644 --- a/integrations/astra/CHANGELOG.md +++ b/integrations/astra/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [integrations/astra-v0.10.0] - 2024-10-22 + +### 🚀 Features + +- Update astradb integration for latest client library (#1145) + +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + ## [integrations/astra-v0.9.3] - 2024-09-12 ### 🐛 Bug Fixes @@ -23,9 +34,7 @@ ### 🐛 Bug Fixes - Fix astra nightly - - Fix typing checks - - `Astra` - Fallback to default filter policy when deserializing retrievers without the init parameter (#896) ### ⚙️ Miscellaneous Tasks @@ -50,8 +59,6 @@ - Fix haystack-ai pin (#649) - - ## [integrations/astra-v0.5.0] - 2024-03-18 ### 📚 Documentation @@ -75,8 +82,6 @@ This PR will also push the docs to Readme - Fix integration tests (#450) - - ## [integrations/astra-v0.4.0] - 2024-02-20 ### 📚 Documentation From 5b3b6c15b03c8ef3da42b3e84f7f252d15de150a Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 23 Oct 2024 11:36:21 +0200 Subject: [PATCH 046/229] fix: make project-id optional in all VertexAI generators (#1147) * Make project-id optional in vertexai gens --- .../generators/google_vertex/captioner.py | 8 +++++--- .../generators/google_vertex/code_generator.py | 8 +++++--- .../generators/google_vertex/image_generator.py | 13 ++++++++++--- .../google_vertex/question_answering.py | 8 +++++--- .../generators/google_vertex/text_generator.py | 8 +++++--- .../google_vertex/tests/test_captioner.py | 15 ++++++--------- .../google_vertex/tests/test_code_generator.py | 15 ++++++--------- .../google_vertex/tests/test_image_generator.py | 8 +++----- .../tests/test_question_answering.py | 8 +++----- .../google_vertex/tests/test_text_generator.py | 14 +++++--------- 10 files changed, 53 insertions(+), 52 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py index 14102eb4b..ff8ce497b 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py @@ -25,7 +25,7 @@ class VertexAIImageCaptioner: from haystack.dataclasses.byte_stream import ByteStream from haystack_integrations.components.generators.google_vertex import VertexAIImageCaptioner - captioner = VertexAIImageCaptioner(project_id=project_id) + captioner = VertexAIImageCaptioner() image = ByteStream( data=requests.get( @@ -41,14 +41,16 @@ class VertexAIImageCaptioner: ``` """ - def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, *, model: str = "imagetext", project_id: Optional[str] = None, location: Optional[str] = None, 
**kwargs + ): """ Generate image captions using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py index c39c7f88b..096e642dd 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py @@ -20,7 +20,7 @@ class VertexAICodeGenerator: ```python from haystack_integrations.components.generators.google_vertex import VertexAICodeGenerator - generator = VertexAICodeGenerator(project_id=project_id) + generator = VertexAICodeGenerator() result = generator.run(prefix="def to_json(data):") @@ -45,14 +45,16 @@ class VertexAICodeGenerator: ``` """ - def __init__(self, *, model: str = "code-bison", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, *, model: str = "code-bison", project_id: Optional[str] = None, location: Optional[str] = None, **kwargs + ): """ Generate code using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. :param kwargs: Additional keyword arguments to pass to the model. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py index 0534a20f2..9301221b5 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py @@ -24,20 +24,27 @@ class VertexAIImageGenerator: from haystack_integrations.components.generators.google_vertex import VertexAIImageGenerator - generator = VertexAIImageGenerator(project_id=project_id) + generator = VertexAIImageGenerator() result = generator.run(prompt="Generate an image of a cute cat") result["images"][0].to_file(Path("my_image.png")) ``` """ - def __init__(self, *, model: str = "imagegeneration", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, + *, + model: str = "imagegeneration", + project_id: Optional[str] = None, + location: Optional[str] = None, + **kwargs, + ): """ Generates images using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). 
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. :param kwargs: Additional keyword arguments to pass to the model. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py index 392a41e00..38eeb7c62 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py @@ -23,7 +23,7 @@ class VertexAIImageQA: from haystack.dataclasses.byte_stream import ByteStream from haystack_integrations.components.generators.google_vertex import VertexAIImageQA - qa = VertexAIImageQA(project_id=project_id) + qa = VertexAIImageQA() image = ByteStream.from_file_path("dog.jpg") @@ -35,14 +35,16 @@ class VertexAIImageQA: ``` """ - def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, *, model: str = "imagetext", project_id: Optional[str] = None, location: Optional[str] = None, **kwargs + ): """ Answers questions about an image using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. :param kwargs: Additional keyword arguments to pass to the model. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py index 59061d91c..4f69dfb18 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py @@ -25,7 +25,7 @@ class VertexAITextGenerator: ```python from haystack_integrations.components.generators.google_vertex import VertexAITextGenerator - generator = VertexAITextGenerator(project_id=project_id) + generator = VertexAITextGenerator() res = generator.run("Tell me a good interview question for a software engineer.") print(res["replies"][0]) @@ -45,14 +45,16 @@ class VertexAITextGenerator: ``` """ - def __init__(self, *, model: str = "text-bison", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, *, model: str = "text-bison", project_id: Optional[str] = None, location: Optional[str] = None, **kwargs + ): """ Generate text using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). 
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. :param kwargs: Additional keyword arguments to pass to the model. diff --git a/integrations/google_vertex/tests/test_captioner.py b/integrations/google_vertex/tests/test_captioner.py index 26249dbee..3d849c738 100644 --- a/integrations/google_vertex/tests/test_captioner.py +++ b/integrations/google_vertex/tests/test_captioner.py @@ -22,14 +22,12 @@ def test_init(mock_model_class, mock_vertexai): @patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai") @patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel") def test_to_dict(_mock_model_class, _mock_vertexai): - captioner = VertexAIImageCaptioner( - model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" - ) + captioner = VertexAIImageCaptioner(model="imagetext", number_of_results=1, language="it") assert captioner.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.captioner.VertexAIImageCaptioner", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "number_of_results": 1, "language": "it", @@ -45,14 +43,15 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.captioner.VertexAIImageCaptioner", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, + "location": None, "number_of_results": 1, "language": "it", }, } ) assert captioner._model_name == "imagetext" - assert captioner._project_id == "myproject-123456" + assert captioner._project_id is None assert captioner._location is None assert captioner._kwargs == {"number_of_results": 1, "language": "it"} assert captioner._model is not None @@ -63,9 +62,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): def test_run_calls_get_captions(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model_class.from_pretrained.return_value = mock_model - captioner = VertexAIImageCaptioner( - model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" - ) + captioner = VertexAIImageCaptioner(model="imagetext", number_of_results=1, language="it") image = ByteStream(data=b"image data") captioner.run(image=image) diff --git a/integrations/google_vertex/tests/test_code_generator.py b/integrations/google_vertex/tests/test_code_generator.py index 129954062..132f4c945 100644 --- a/integrations/google_vertex/tests/test_code_generator.py +++ b/integrations/google_vertex/tests/test_code_generator.py @@ -22,14 +22,12 @@ def test_init(mock_model_class, mock_vertexai): @patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai") @patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel") def test_to_dict(_mock_model_class, _mock_vertexai): - generator = VertexAICodeGenerator( - model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5 - ) + generator = VertexAICodeGenerator(model="code-bison", candidate_count=3, temperature=0.5) assert 
generator.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.code_generator.VertexAICodeGenerator", "init_parameters": { "model": "code-bison", - "project_id": "myproject-123456", + "project_id": None, "location": None, "candidate_count": 3, "temperature": 0.5, @@ -45,14 +43,15 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.code_generator.VertexAICodeGenerator", "init_parameters": { "model": "code-bison", - "project_id": "myproject-123456", + "project_id": None, + "location": None, "candidate_count": 2, "temperature": 0.5, }, } ) assert generator._model_name == "code-bison" - assert generator._project_id == "myproject-123456" + assert generator._project_id is None assert generator._location is None assert generator._kwargs == {"candidate_count": 2, "temperature": 0.5} assert generator._model is not None @@ -64,9 +63,7 @@ def test_run_calls_predict(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model.predict.return_value = TextGenerationResponse("answer", None) mock_model_class.from_pretrained.return_value = mock_model - generator = VertexAICodeGenerator( - model="code-bison", project_id="myproject-123456", candidate_count=1, temperature=0.5 - ) + generator = VertexAICodeGenerator(model="code-bison", candidate_count=1, temperature=0.5) prefix = "def print_json(data):\n" generator.run(prefix=prefix) diff --git a/integrations/google_vertex/tests/test_image_generator.py b/integrations/google_vertex/tests/test_image_generator.py index 6cd42a11c..860c1ec43 100644 --- a/integrations/google_vertex/tests/test_image_generator.py +++ b/integrations/google_vertex/tests/test_image_generator.py @@ -30,7 +30,6 @@ def test_init(mock_model_class, mock_vertexai): def test_to_dict(_mock_model_class, _mock_vertexai): generator = VertexAIImageGenerator( model="imagetext", - project_id="myproject-123456", guidance_scale=12, number_of_images=3, ) @@ -38,7 +37,7 @@ def test_to_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.image_generator.VertexAIImageGenerator", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "guidance_scale": 12, "number_of_images": 3, @@ -54,7 +53,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.image_generator.VertexAIImageGenerator", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "guidance_scale": 12, "number_of_images": 3, @@ -62,7 +61,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): } ) assert generator._model_name == "imagetext" - assert generator._project_id == "myproject-123456" + assert generator._project_id is None assert generator._location is None assert generator._kwargs == { "guidance_scale": 12, @@ -78,7 +77,6 @@ def test_run_calls_generate_images(mock_model_class, _mock_vertexai): mock_model_class.from_pretrained.return_value = mock_model generator = VertexAIImageGenerator( model="imagetext", - project_id="myproject-123456", guidance_scale=12, number_of_images=3, ) diff --git a/integrations/google_vertex/tests/test_question_answering.py b/integrations/google_vertex/tests/test_question_answering.py index 3f414f0e0..a36e47b6f 100644 --- a/integrations/google_vertex/tests/test_question_answering.py +++ b/integrations/google_vertex/tests/test_question_answering.py @@ -26,14 
+26,13 @@ def test_init(mock_model_class, mock_vertexai): def test_to_dict(_mock_model_class, _mock_vertexai): generator = VertexAIImageQA( model="imagetext", - project_id="myproject-123456", number_of_results=3, ) assert generator.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.question_answering.VertexAIImageQA", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "number_of_results": 3, }, @@ -48,14 +47,14 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.question_answering.VertexAIImageQA", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "number_of_results": 3, }, } ) assert generator._model_name == "imagetext" - assert generator._project_id == "myproject-123456" + assert generator._project_id is None assert generator._location is None assert generator._kwargs == {"number_of_results": 3} @@ -68,7 +67,6 @@ def test_run_calls_ask_question(mock_model_class, _mock_vertexai): mock_model_class.from_pretrained.return_value = mock_model generator = VertexAIImageQA( model="imagetext", - project_id="myproject-123456", number_of_results=3, ) diff --git a/integrations/google_vertex/tests/test_text_generator.py b/integrations/google_vertex/tests/test_text_generator.py index 3e5248dc7..cc3f15312 100644 --- a/integrations/google_vertex/tests/test_text_generator.py +++ b/integrations/google_vertex/tests/test_text_generator.py @@ -24,14 +24,12 @@ def test_init(mock_model_class, mock_vertexai): @patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel") def test_to_dict(_mock_model_class, _mock_vertexai): grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1") - generator = VertexAITextGenerator( - model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source - ) + generator = VertexAITextGenerator(model="text-bison", temperature=0.2, grounding_source=grounding_source) assert generator.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.text_generator.VertexAITextGenerator", "init_parameters": { "model": "text-bison", - "project_id": "myproject-123456", + "project_id": None, "location": None, "temperature": 0.2, "grounding_source": { @@ -55,7 +53,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.text_generator.VertexAITextGenerator", "init_parameters": { "model": "text-bison", - "project_id": "myproject-123456", + "project_id": None, "location": None, "temperature": 0.2, "grounding_source": { @@ -71,7 +69,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): } ) assert generator._model_name == "text-bison" - assert generator._project_id == "myproject-123456" + assert generator._project_id is None assert generator._location is None assert generator._kwargs == { "temperature": 0.2, @@ -86,9 +84,7 @@ def test_run_calls_get_captions(mock_model_class, _mock_vertexai): mock_model.predict.return_value = MagicMock() mock_model_class.from_pretrained.return_value = mock_model grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1") - generator = VertexAITextGenerator( - model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source - ) + generator = VertexAITextGenerator(model="text-bison", 
temperature=0.2, grounding_source=grounding_source) prompt = "What is the answer?" generator.run(prompt=prompt) From f57ec1aed5100feb3f0248f20240890aec1998da Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 23 Oct 2024 12:07:44 +0000 Subject: [PATCH 047/229] Update the changelog --- integrations/google_vertex/CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index 23ab51b3d..ed2cc3c3b 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [integrations/google_vertex-v2.2.0] - 2024-10-23 + +### 🐛 Bug Fixes + +- Make "project-id" parameter optional during initialization (#1141) +- Make project-id optional in all VertexAI generators (#1147) + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/google_vertex-v2.1.0] - 2024-10-04 ### 🚀 Features From ae207f0b94d0074e348b9458ccf46d0f36c3e850 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 23 Oct 2024 14:12:25 +0200 Subject: [PATCH 048/229] refactor: avoid downloading tokenizer if `truncate` is `False` (#1152) * avoid downloading tokenizer if truncate is False * fix --- .../generators/amazon_bedrock/generator.py | 19 ++++++++++--------- .../amazon_bedrock/tests/test_generator.py | 8 ++++++++ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index c6c814de4..941fdbf71 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -152,15 +152,16 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: # We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed model_max_length = kwargs.get("model_max_length", 4096) - # Truncate prompt if prompt tokens > model_max_length-max_length - # (max_length is the length of the generated text) - # we use GPT2 tokenizer which will likely provide good token count approximation - - self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", - model_max_length=model_max_length, - max_length=self.max_length or 100, - ) + # we initialize the prompt handler only if truncate is True: we avoid unnecessarily downloading the tokenizer + if self.truncate: + # Truncate prompt if prompt tokens > model_max_length-max_length + # (max_length is the length of the generated text) + # we use GPT2 tokenizer which will likely provide good token count approximation + self.prompt_handler = DefaultPromptHandler( + tokenizer="gpt2", + model_max_length=model_max_length, + max_length=self.max_length or 100, + ) model_adapter_cls = self.get_model_adapter(model=model) if not model_adapter_cls: diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 79246b4aa..be645218e 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -108,6 +108,14 @@ def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_ assert layer.prompt_handler.model_max_length == 4096 +def 
test_prompt_handler_absent_when_truncate_false(mock_boto3_session): + """ + Test that the prompt_handler is not initialized when truncate is set to False. + """ + generator = AmazonBedrockGenerator(model="anthropic.claude-v2", truncate=False) + assert not hasattr(generator, "prompt_handler") + + def test_constructor_with_model_kwargs(mock_boto3_session): """ Test that model_kwargs are correctly set in the constructor From 5034f96b23cca84c1b6f4a60eb9402d748a33cec Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 23 Oct 2024 12:13:49 +0000 Subject: [PATCH 049/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index de3647acc..1068e870a 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/amazon_bedrock-v1.1.0] - 2024-10-23 + +### 🚜 Refactor + +- Avoid downloading tokenizer if `truncate` is `False` (#1152) + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/amazon_bedrock-v1.0.5] - 2024-10-17 ### 🚀 Features From 2602bc1e4371c3d46c4ff62073aea718f65c12d9 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 28 Oct 2024 15:02:01 +0100 Subject: [PATCH 050/229] elasticsearch - allow passing headers (#1156) --- .../elasticsearch/document_store.py | 5 ++++- .../elasticsearch/tests/test_document_store.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py index 734e2d2b8..8dfb07919 100644 --- a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py +++ b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py @@ -105,9 +105,12 @@ def __init__( @property def client(self) -> Elasticsearch: if self._client is None: + headers = self._kwargs.pop("headers", {}) + headers["user-agent"] = f"haystack-py-ds/{haystack_version}" + client = Elasticsearch( self._hosts, - headers={"user-agent": f"haystack-py-ds/{haystack_version}"}, + headers=headers, **self._kwargs, ) # Check client connection, this will raise if not connected diff --git a/integrations/elasticsearch/tests/test_document_store.py b/integrations/elasticsearch/tests/test_document_store.py index 51a19b641..d636ff027 100644 --- a/integrations/elasticsearch/tests/test_document_store.py +++ b/integrations/elasticsearch/tests/test_document_store.py @@ -22,6 +22,20 @@ def test_init_is_lazy(_mock_es_client): _mock_es_client.assert_not_called() +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") +def test_headers_are_supported(_mock_es_client): + _ = ElasticsearchDocumentStore(hosts="testhost", headers={"header1": "value1", "header2": "value2"}).client + + assert _mock_es_client.call_count == 1 + _, kwargs = _mock_es_client.call_args + + headers_found = kwargs["headers"] + assert headers_found["header1"] == "value1" + assert headers_found["header2"] == "value2" + + assert headers_found["user-agent"].startswith("haystack-py-ds/") + + @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_to_dict(_mock_elasticsearch_client): document_store = 
ElasticsearchDocumentStore(hosts="some hosts") From 7eb062c63dcb6dfa6bbaa4ba393e76ff39715ac4 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 28 Oct 2024 14:04:20 +0000 Subject: [PATCH 051/229] Update the changelog --- integrations/elasticsearch/CHANGELOG.md | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/integrations/elasticsearch/CHANGELOG.md b/integrations/elasticsearch/CHANGELOG.md index 5d2b66470..bd8bff63c 100644 --- a/integrations/elasticsearch/CHANGELOG.md +++ b/integrations/elasticsearch/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [integrations/elasticsearch-v1.0.1] - 2024-10-28 + +### ⚙️ Miscellaneous Tasks + +- Update changelog after removing legacy filters (#1083) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + ## [integrations/elasticsearch-v1.0.0] - 2024-09-12 ### 🚀 Features @@ -69,8 +77,6 @@ This PR will also push the docs to Readme - Fix project urls (#96) - - ### 🚜 Refactor - Use `hatch_vcs` to manage integrations versioning (#103) @@ -81,15 +87,12 @@ This PR will also push the docs to Readme - Fix import and increase version (#77) - - ## [integrations/elasticsearch-v0.1.0] - 2023-12-04 ### 🐛 Bug Fixes - Fix license headers - ## [integrations/elasticsearch-v0.0.2] - 2023-11-29 From 3220330f64baa90670dd4bb2a84661b674e4ff08 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 28 Oct 2024 17:11:31 +0100 Subject: [PATCH 052/229] introduce stalebot (#1157) --- .github/workflows/CI_stale.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 .github/workflows/CI_stale.yml diff --git a/.github/workflows/CI_stale.yml b/.github/workflows/CI_stale.yml new file mode 100644 index 000000000..5a4b3b467 --- /dev/null +++ b/.github/workflows/CI_stale.yml @@ -0,0 +1,15 @@ +name: 'Stalebot' +on: + schedule: + - cron: '30 1 * * *' + +jobs: + makestale: + runs-on: ubuntu-latest + steps: + - uses: actions/stale@v9 + with: + any-of-labels: 'community-triage' + stale-pr-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 10 days.' + days-before-stale: 30 + days-before-close: 10 \ No newline at end of file From 6e8ee96b7d30fbc8a1c1cc73c27ab3078ddcb18e Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 29 Oct 2024 13:01:37 +0100 Subject: [PATCH 053/229] remove index param from some methods (#1160) --- .../document_stores/qdrant/document_store.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index 88afd8f65..c8cb9a393 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -362,7 +362,6 @@ def write_documents( document_objects = self._handle_duplicate_documents( documents=documents, - index=self.index, policy=policy, ) @@ -468,7 +467,6 @@ def get_documents_generator( def get_documents_by_id( self, ids: List[str], - index: Optional[str] = None, ) -> List[Document]: """ Retrieves documents from Qdrant by their IDs. @@ -480,13 +478,11 @@ def get_documents_by_id( :returns: A list of documents. 
""" - index = index or self.index - documents: List[Document] = [] ids = [convert_id(_id) for _id in ids] records = self.client.retrieve( - collection_name=index, + collection_name=self.index, ids=ids, with_payload=True, with_vectors=True, @@ -987,7 +983,6 @@ def recreate_collection( def _handle_duplicate_documents( self, documents: List[Document], - index: Optional[str] = None, policy: DuplicatePolicy = None, ): """ @@ -995,31 +990,28 @@ def _handle_duplicate_documents( documents that are not in the index yet. :param documents: A list of Haystack Document objects. - :param index: name of the index :param policy: The duplicate policy to use when writing documents. :returns: A list of Haystack Document objects. """ - index = index or self.index if policy in (DuplicatePolicy.SKIP, DuplicatePolicy.FAIL): - documents = self._drop_duplicate_documents(documents, index) - documents_found = self.get_documents_by_id(ids=[doc.id for doc in documents], index=index) + documents = self._drop_duplicate_documents(documents) + documents_found = self.get_documents_by_id(ids=[doc.id for doc in documents]) ids_exist_in_db: List[str] = [doc.id for doc in documents_found] if len(ids_exist_in_db) > 0 and policy == DuplicatePolicy.FAIL: - msg = f"Document with ids '{', '.join(ids_exist_in_db)} already exists in index = '{index}'." + msg = f"Document with ids '{', '.join(ids_exist_in_db)} already exists in index = '{self.index}'." raise DuplicateDocumentError(msg) documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents)) return documents - def _drop_duplicate_documents(self, documents: List[Document], index: Optional[str] = None) -> List[Document]: + def _drop_duplicate_documents(self, documents: List[Document]) -> List[Document]: """ Drop duplicate documents based on same hash ID. :param documents: A list of Haystack Document objects. - :param index: Name of the index. :returns: A list of Haystack Document objects. 
""" _hash_ids: Set = set() @@ -1030,7 +1022,7 @@ def _drop_duplicate_documents(self, documents: List[Document], index: Optional[s logger.info( "Duplicate Documents: Document with id '%s' already exists in index '%s'", document.id, - index or self.index, + self.index, ) continue _documents.append(document) From 6cc39e68e1bcf679b1c57b5cba9b4e917086ea86 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 29 Oct 2024 12:04:01 +0000 Subject: [PATCH 054/229] Update the changelog --- integrations/qdrant/CHANGELOG.md | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/integrations/qdrant/CHANGELOG.md b/integrations/qdrant/CHANGELOG.md index a275529f8..57dd257d5 100644 --- a/integrations/qdrant/CHANGELOG.md +++ b/integrations/qdrant/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/qdrant-v7.0.0] - 2024-10-29 + +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + ## [integrations/qdrant-v6.0.0] - 2024-09-13 ## [integrations/qdrant-v5.1.0] - 2024-09-12 @@ -105,8 +112,6 @@ - Fix haystack-ai pin (#649) - - ## [integrations/qdrant-v3.2.0] - 2024-03-27 ### 🚀 Features @@ -117,15 +122,11 @@ ### 🐛 Bug Fixes - Fix linter errors (#282) - - - Fix order of API docs (#447) This PR will also push the docs to Readme - Fixes (#518) - - ### 🚜 Refactor - [**breaking**] Qdrant - update secret management (#405) @@ -156,8 +157,6 @@ This PR will also push the docs to Readme - Fix import paths for beta5 (#237) - - ### 🚜 Refactor - Use `hatch_vcs` to manage integrations versioning (#103) From 701a790ebd997673400486a8b332bb1465635fa8 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:49:29 +0100 Subject: [PATCH 055/229] feat: efficient knn filtering support for OpenSearch (#1134) * feat: efficient filtering support for OpenSearch * add hint about supported knn engines to docstring * Apply suggestions from code review --- .../opensearch/embedding_retriever.py | 12 +++++ .../opensearch/document_store.py | 8 +++- .../opensearch/tests/test_document_store.py | 44 +++++++++++++++++++ .../tests/test_embedding_retriever.py | 15 ++++++- 4 files changed, 76 insertions(+), 3 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index e159634cf..1e9bb9132 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -31,6 +31,7 @@ def __init__( filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, custom_query: Optional[Dict[str, Any]] = None, raise_on_failure: bool = True, + efficient_filtering: bool = False, ): """ Create the OpenSearchEmbeddingRetriever component. @@ -85,6 +86,8 @@ def __init__( :param raise_on_failure: If `True`, raises an exception if the API call fails. If `False`, logs a warning and returns an empty list. + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. 
""" @@ -100,6 +103,7 @@ def __init__( ) self._custom_query = custom_query self._raise_on_failure = raise_on_failure + self._efficient_filtering = efficient_filtering def to_dict(self) -> Dict[str, Any]: """ @@ -116,6 +120,7 @@ def to_dict(self) -> Dict[str, Any]: filter_policy=self._filter_policy.value, custom_query=self._custom_query, raise_on_failure=self._raise_on_failure, + efficient_filtering=self._efficient_filtering, ) @classmethod @@ -146,6 +151,7 @@ def run( filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, custom_query: Optional[Dict[str, Any]] = None, + efficient_filtering: Optional[bool] = None, ): """ Retrieve documents using a vector similarity metric. @@ -196,6 +202,9 @@ def run( ) ``` + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". + :returns: Dictionary with key "documents" containing the retrieved Documents. - documents: List of Document similar to `query_embedding`. @@ -208,6 +217,8 @@ def run( top_k = self._top_k if custom_query is None: custom_query = self._custom_query + if efficient_filtering is None: + efficient_filtering = self._efficient_filtering docs: List[Document] = [] @@ -217,6 +228,7 @@ def run( filters=filters, top_k=top_k, custom_query=custom_query, + efficient_filtering=efficient_filtering, ) except Exception as e: if self._raise_on_failure: diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 6f7a6c96e..4ec2420b3 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -438,6 +438,7 @@ def _embedding_retrieval( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, custom_query: Optional[Dict[str, Any]] = None, + efficient_filtering: bool = False, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. @@ -474,6 +475,8 @@ def _embedding_retrieval( } ``` + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". 
:raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` """ @@ -509,7 +512,10 @@ def _embedding_retrieval( } if filters: - body["query"]["bool"]["filter"] = normalize_filters(filters) + if efficient_filtering: + body["query"]["bool"]["must"][0]["knn"]["embedding"]["filter"] = normalize_filters(filters) + else: + body["query"]["bool"]["filter"] = normalize_filters(filters) body["size"] = top_k diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 9cc4bf4ea..043f59891 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -337,6 +337,27 @@ def document_store_embedding_dim_4(self, request): yield store store.client.indices.delete(index=index, params={"ignore": [400, 404]}) + @pytest.fixture + def document_store_embedding_dim_4_faiss(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = ["https://localhost:9200"] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=4, + method={"space_type": "innerproduct", "engine": "faiss", "name": "hnsw"}, + ) + yield store + store.client.indices.delete(index=index, params={"ignore": [400, 404]}) + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ The OpenSearchDocumentStore.filter_documents() method returns a Documents with their score set. @@ -690,6 +711,29 @@ def test_embedding_retrieval_with_filters(self, document_store_embedding_dim_4: assert len(results) == 1 assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_with_filters_efficient_filtering( + self, document_store_embedding_dim_4_faiss: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4_faiss.write_documents(docs) + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + results = document_store_embedding_dim_4_faiss._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], + filters=filters, + efficient_filtering=True, + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: OpenSearchDocumentStore): """ Test that handling of pagination works as expected, when the matching documents are > 10. 
diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 75c191946..84e9828ca 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -19,6 +19,7 @@ def test_init_default(): assert retriever._filters == {} assert retriever._top_k == 10 assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._efficient_filtering is False retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE @@ -82,6 +83,7 @@ def test_to_dict(_mock_opensearch_client): "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": True, + "efficient_filtering": False, }, } @@ -101,6 +103,7 @@ def test_from_dict(_mock_opensearch_client): "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": False, + "efficient_filtering": True, }, } retriever = OpenSearchEmbeddingRetriever.from_dict(data) @@ -110,6 +113,7 @@ def test_from_dict(_mock_opensearch_client): assert retriever._custom_query == {"some": "custom query"} assert retriever._raise_on_failure is False assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._efficient_filtering is True # For backwards compatibility with older versions of the retriever without a filter policy data = { @@ -139,6 +143,7 @@ def test_run(): filters={}, top_k=10, custom_query=None, + efficient_filtering=False, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -150,7 +155,11 @@ def test_run_init_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] retriever = OpenSearchEmbeddingRetriever( - document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query" + document_store=mock_store, + filters={"from": "init"}, + top_k=11, + custom_query="custom_query", + efficient_filtering=True, ) res = retriever.run(query_embedding=[0.5, 0.7]) mock_store._embedding_retrieval.assert_called_once_with( @@ -158,6 +167,7 @@ def test_run_init_params(): filters={"from": "init"}, top_k=11, custom_query="custom_query", + efficient_filtering=True, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -169,12 +179,13 @@ def test_run_time_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11) - res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9) + res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, efficient_filtering=True) mock_store._embedding_retrieval.assert_called_once_with( query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, custom_query=None, + efficient_filtering=True, ) assert len(res) == 1 assert len(res["documents"]) == 1 From e97cfee4c2b0868d96e4a4a335ec5cd7ee0f6cf1 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 29 Oct 2024 16:57:12 +0000 Subject: [PATCH 056/229] Update the changelog --- integrations/opensearch/CHANGELOG.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/integrations/opensearch/CHANGELOG.md b/integrations/opensearch/CHANGELOG.md index 713848915..afd8a57c2 100644 --- 
a/integrations/opensearch/CHANGELOG.md +++ b/integrations/opensearch/CHANGELOG.md @@ -1,6 +1,10 @@ # Changelog -## [integrations/opensearch-v1.0.0] - 2024-09-12 +## [integrations/opensearch-v1.1.0] - 2024-10-29 + +### 🚀 Features + +- Efficient knn filtering support for OpenSearch (#1134) ### 📚 Documentation @@ -13,6 +17,9 @@ ### ⚙️ Miscellaneous Tasks - OpenSearch - remove legacy filter support (#1067) +- Update changelog after removing legacy filters (#1083) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) ### Docs @@ -83,8 +90,6 @@ This PR will also push the docs to Readme - Fix links in docstrings (#188) - - ### 🚜 Refactor - Use `hatch_vcs` to manage integrations versioning (#103) @@ -95,15 +100,12 @@ This PR will also push the docs to Readme - Fix import and increase version (#77) - - ## [integrations/opensearch-v0.1.0] - 2023-12-04 ### 🐛 Bug Fixes - Fix license headers - ## [integrations/opensearch-v0.0.2] - 2023-11-30 ### 🚀 Features From 8db9990764dd95db2caac82de0b969bf456a7800 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 31 Oct 2024 10:47:06 +0100 Subject: [PATCH 057/229] chore: fix Vertex tests (#1163) * try fixing tests * more fixes --- integrations/google_vertex/tests/chat/test_gemini.py | 1 + integrations/google_vertex/tests/test_gemini.py | 1 + 2 files changed, 2 insertions(+) diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 0d77bd9c6..73c99fe2f 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -170,6 +170,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], + "property_ordering": ["location", "unit"], }, } ] diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index b3d6dd5f5..277851224 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -157,6 +157,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], + "property_ordering": ["location", "unit"], }, } ] From 5c4a38544bb72be25d6161a71c6fda03799ff74f Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 4 Nov 2024 06:45:24 +0100 Subject: [PATCH 058/229] pin onnxruntime (#1164) --- integrations/fastembed/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/fastembed/pyproject.toml b/integrations/fastembed/pyproject.toml index 8686c9e7a..b9f1f6cfd 100644 --- a/integrations/fastembed/pyproject.toml +++ b/integrations/fastembed/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.2.5"] +dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.2.5", "onnxruntime<1.20.0"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" From f51de64ec485b5771fb3b23b5054411b0f98586d Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 6 Nov 2024 10:28:00 +0100 Subject: [PATCH 059/229] fix: adapt our implementation to breaking changes in Chroma 0.5.17 (#1165) * fix chroma breaking changes * improve warning * better warning --- 
integrations/chroma/pyproject.toml | 2 +- .../document_stores/chroma/document_store.py | 7 +- .../document_stores/chroma/filters.py | 8 +-- .../chroma/tests/test_document_store.py | 65 ++++++++++++++++++- 4 files changed, 74 insertions(+), 8 deletions(-) diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 7f0943a30..cfe7a606e 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "chromadb>=0.5.0", "typing_extensions>=4.8.0"] +dependencies = ["haystack-ai", "chromadb>=0.5.17", "typing_extensions>=4.8.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma#readme" diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 6a83937a4..439e4b144 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -248,9 +248,12 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D if doc.content is None: logger.warning( - "ChromaDocumentStore can only store the text field of Documents: " - "'array', 'dataframe' and 'blob' will be dropped." + "ChromaDocumentStore cannot store documents with `content=None`. " + "`array`, `dataframe` and `blob` are not supported. " + "Document with id %s will be skipped.", + doc.id, ) + continue data = {"ids": [doc.id], "documents": [doc.content]} if doc.meta: diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py index 60046b6ad..df49da673 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py @@ -1,6 +1,6 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from chromadb.api.types import validate_where, validate_where_document @@ -34,8 +34,8 @@ class ChromaFilter: """ ids: List[str] - where: Dict[str, Any] - where_document: Dict[str, Any] + where: Optional[Dict[str, Any]] + where_document: Optional[Dict[str, Any]] def _convert_filters(filters: Dict[str, Any]) -> ChromaFilter: @@ -80,7 +80,7 @@ def _convert_filters(filters: Dict[str, Any]) -> ChromaFilter: msg = f"Invalid '{test_clause}' : {e}" raise ChromaDocumentStoreFilterError(msg) from e - return ChromaFilter(ids=ids, where=where, where_document=where_document) + return ChromaFilter(ids=ids, where=where or None, where_document=where_document or None) def _convert_filter_clause(filters: Dict[str, Any]) -> Dict[str, Any]: diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index 987f6d8b7..ed815251e 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -13,9 +13,12 @@ from chromadb.api.types import Documents, EmbeddingFunction, Embeddings from haystack import Document from haystack.testing.document_store import ( + TEST_EMBEDDING_1, 
+ TEST_EMBEDDING_2, CountDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest, + _random_embeddings, ) from haystack_integrations.document_stores.chroma import ChromaDocumentStore @@ -51,6 +54,67 @@ def document_store(self) -> ChromaDocumentStore: get_func.return_value = _TestEmbeddingFunction() return ChromaDocumentStore(embedding_function="test_function", collection_name=str(uuid.uuid1())) + @pytest.fixture + def filterable_docs(self) -> List[Document]: + """ + This fixture has been copied from haystack/testing/document_store.py and modified to + remove the documents that don't have textual content, as Chroma does not support writing them. + """ + documents = [] + for i in range(3): + documents.append( + Document( + content=f"A Foo Document {i}", + meta={ + "name": f"name_{i}", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Bar Document {i}", + meta={ + "name": f"name_{i}", + "page": "123", + "chapter": "abstract", + "number": -2, + "date": "1972-12-11T19:54:58", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Foobar Document {i}", + meta={ + "name": f"name_{i}", + "page": "90", + "chapter": "conclusion", + "number": -10, + "date": "1989-11-09T17:53:00", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"Document {i} without embedding", + meta={"name": f"name_{i}", "no_embedding": True, "chapter": "conclusion"}, + ) + ) + documents.append( + Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) + ) + documents.append( + Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) + ) + return documents + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. 
@@ -283,7 +347,6 @@ def test_contains(self, document_store: ChromaDocumentStore, filterable_docs: Li ) def test_multiple_contains(self, document_store: ChromaDocumentStore, filterable_docs: List[Document]): - filterable_docs = [doc for doc in filterable_docs if doc.content] # remove documents without content document_store.write_documents(filterable_docs) filters = { "operator": "OR", From 41c768d275163f48fbd997f6f5664f31e8572267 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 6 Nov 2024 09:32:25 +0000 Subject: [PATCH 060/229] Update the changelog --- integrations/chroma/CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/integrations/chroma/CHANGELOG.md b/integrations/chroma/CHANGELOG.md index c129d00ae..591c0ec39 100644 --- a/integrations/chroma/CHANGELOG.md +++ b/integrations/chroma/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [integrations/chroma-v1.0.0] - 2024-11-06 + +### 🐛 Bug Fixes + +- Fixing Chroma tests due `chromadb` update behaviour change (#1148) +- Adapt our implementation to breaking changes in Chroma 0.5.17 (#1165) + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/chroma-v0.22.1] - 2024-09-30 ### Chroma From 06d77cce287ca0e2fa441907230bb1a813818da5 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 7 Nov 2024 11:05:56 +0100 Subject: [PATCH 061/229] feat: Add Azure AI Search integration (#1122) * Azure AI Search Document Store --- .github/workflows/azure_ai_search.yml | 72 +++ integrations/azure_ai_search/LICENSE | 201 ++++++++ integrations/azure_ai_search/README.md | 26 ++ .../azure_ai_search/example/document_store.py | 44 ++ .../example/embedding_retrieval.py | 58 +++ integrations/azure_ai_search/pydoc/config.yml | 31 ++ integrations/azure_ai_search/pyproject.toml | 163 +++++++ .../retrievers/azure_ai_search/__init__.py | 3 + .../azure_ai_search/embedding_retriever.py | 116 +++++ .../azure_ai_search/__init__.py | 7 + .../azure_ai_search/document_store.py | 440 ++++++++++++++++++ .../document_stores/azure_ai_search/errors.py | 20 + .../azure_ai_search/filters.py | 112 +++++ .../azure_ai_search/tests/__init__.py | 3 + .../azure_ai_search/tests/conftest.py | 68 +++ .../tests/test_document_store.py | 410 ++++++++++++++++ .../tests/test_embedding_retriever.py | 145 ++++++ 17 files changed, 1919 insertions(+) create mode 100644 .github/workflows/azure_ai_search.yml create mode 100644 integrations/azure_ai_search/LICENSE create mode 100644 integrations/azure_ai_search/README.md create mode 100644 integrations/azure_ai_search/example/document_store.py create mode 100644 integrations/azure_ai_search/example/embedding_retrieval.py create mode 100644 integrations/azure_ai_search/pydoc/config.yml create mode 100644 integrations/azure_ai_search/pyproject.toml create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py create mode 100644 
integrations/azure_ai_search/tests/__init__.py create mode 100644 integrations/azure_ai_search/tests/conftest.py create mode 100644 integrations/azure_ai_search/tests/test_document_store.py create mode 100644 integrations/azure_ai_search/tests/test_embedding_retriever.py diff --git a/.github/workflows/azure_ai_search.yml b/.github/workflows/azure_ai_search.yml new file mode 100644 index 000000000..1c10edc91 --- /dev/null +++ b/.github/workflows/azure_ai_search.yml @@ -0,0 +1,72 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / azure_ai_search + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/azure_ai_search/**" + - ".github/workflows/azure_ai_search.yml" + +concurrency: + group: azure_ai_search-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }} + AZURE_SEARCH_SERVICE_ENDPOINT: ${{ secrets.AZURE_SEARCH_SERVICE_ENDPOINT }} + +defaults: + run: + working-directory: integrations/azure_ai_search + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + max-parallel: 3 + matrix: + os: [ubuntu-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run cov-retry + + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/integrations/azure_ai_search/LICENSE b/integrations/azure_ai_search/LICENSE new file mode 100644 index 000000000..de4c7f39f --- /dev/null +++ b/integrations/azure_ai_search/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/integrations/azure_ai_search/README.md b/integrations/azure_ai_search/README.md new file mode 100644 index 000000000..915a23b63 --- /dev/null +++ b/integrations/azure_ai_search/README.md @@ -0,0 +1,26 @@ +# Azure AI Search Document Store for Haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) + +----- + +**Table of Contents** + +- [Azure AI Search Document Store for Haystack](#azure-ai-search-document-store-for-haystack) + - [Installation](#installation) + - [Examples](#examples) + - [License](#license) + +## Installation + +```console +pip install azure-ai-search-haystack +``` + +## Examples +You can find a code example showing how to use the Document Store and the Retriever in the documentation or in [this Colab](https://colab.research.google.com/drive/1YpDetI8BRbObPDEVdfqUcwhEX9UUXP-m?usp=sharing). + +## License + +`azure-ai-search-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py new file mode 100644 index 000000000..779f28935 --- /dev/null +++ b/integrations/azure_ai_search/example/document_store.py @@ -0,0 +1,44 @@ +from haystack import Document +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +""" +This example demonstrates how to use the AzureAISearchDocumentStore to write and filter documents. +To run this example, you'll need an Azure Search service endpoint and API key, which can either be +set as environment variables (AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY) or +provided directly to AzureAISearchDocumentStore(as params "api_key", "azure_endpoint"). +Otherwise you can use DefaultAzureCredential to authenticate with Azure services. 
+See more details at https://learn.microsoft.com/en-us/azure/search/keyless-connections?tabs=python%2Cazure-cli +""" +document_store = AzureAISearchDocumentStore( + metadata_fields={"version": float, "label": str}, + index_name="document-store-example", +) + +documents = [ + Document( + content="This is an introduction to using Python for data analysis.", + meta={"version": 1.0, "label": "chapter_one"}, + ), + Document( + content="Learn how to use Python libraries for machine learning.", + meta={"version": 1.5, "label": "chapter_two"}, + ), + Document( + content="Advanced Python techniques for data visualization.", + meta={"version": 2.0, "label": "chapter_three"}, + ), +] +document_store.write_documents(documents, policy=DuplicatePolicy.SKIP) + +filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.version", "operator": ">", "value": 1.2}, + {"field": "meta.label", "operator": "in", "value": ["chapter_one", "chapter_three"]}, + ], +} + +results = document_store.filter_documents(filters) +print(results) diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py new file mode 100644 index 000000000..088b08653 --- /dev/null +++ b/integrations/azure_ai_search/example/embedding_retrieval.py @@ -0,0 +1,58 @@ +from haystack import Document, Pipeline +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.writers import DocumentWriter +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +""" +This example demonstrates how to use the AzureAISearchEmbeddingRetriever to retrieve documents +using embeddings based on a query. To run this example, you'll need an Azure Search service endpoint +and API key, which can either be +set as environment variables (AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY) or +provided directly to AzureAISearchDocumentStore(as params "api_key", "azure_endpoint"). +Otherwise you can use DefaultAzureCredential to authenticate with Azure services. 
+See more details at https://learn.microsoft.com/en-us/azure/search/keyless-connections?tabs=python%2Cazure-cli +""" + +document_store = AzureAISearchDocumentStore(index_name="retrieval-example") + +model = "sentence-transformers/all-mpnet-base-v2" + +documents = [ + Document(content="There are over 7,000 languages spoken around the world today."), + Document( + content="""Elephants have been observed to behave in a way that indicates a + high level of self-awareness, such as recognizing themselves in mirrors.""" + ), + Document( + content="""In certain parts of the world, like the Maldives, Puerto Rico, and + San Diego, you can witness the phenomenon of bioluminescent waves.""" + ), +] + +document_embedder = SentenceTransformersDocumentEmbedder(model=model) +document_embedder.warm_up() + +# Indexing Pipeline +indexing_pipeline = Pipeline() +indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder") +indexing_pipeline.add_component( + instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer" +) +indexing_pipeline.connect("doc_embedder", "doc_writer") + +indexing_pipeline.run({"doc_embedder": {"documents": documents}}) + +# Query Pipeline +query_pipeline = Pipeline() +query_pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder(model=model)) +query_pipeline.add_component("retriever", AzureAISearchEmbeddingRetriever(document_store=document_store)) +query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + +query = "How many languages are there?" + +result = query_pipeline.run({"text_embedder": {"text": query}}) + +print(result["retriever"]["documents"][0]) diff --git a/integrations/azure_ai_search/pydoc/config.yml b/integrations/azure_ai_search/pydoc/config.yml new file mode 100644 index 000000000..ec411af60 --- /dev/null +++ b/integrations/azure_ai_search/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever", + "haystack_integrations.document_stores.azure_ai_search.document_store", + "haystack_integrations.document_stores.azure_ai_search.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Azure AI Search integration for Haystack + category_slug: integrations-api + title: Azure AI Search + slug: integrations-azure_ai_search + order: 180 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_azure_ai_search.md diff --git a/integrations/azure_ai_search/pyproject.toml b/integrations/azure_ai_search/pyproject.toml new file mode 100644 index 000000000..49ca623e7 --- /dev/null +++ b/integrations/azure_ai_search/pyproject.toml @@ -0,0 +1,163 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "azure-ai-search-haystack" +dynamic = ["version"] +description = 'Haystack 2.x Document Store for Azure AI Search' +readme = "README.md" +requires-python = ">=3.8,<3.13" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset", email = "info@deepset.ai" }] +classifiers = [ + "License :: 
OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", "azure-search-documents>=11.5", "azure-identity"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/azure-ai-search-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/azure-ai-search-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-rerunfailures", + "pytest-xdist", + "haystack-pydoc-tools", +] + +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:src/}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] +all = ["style", "typing"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 + +[tool.ruff.lint] +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... 
True)` + "FBT003", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] +exclude = ["example"] + +[tool.ruff.lint.isort] +known-first-party = ["src"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.lint.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252", "S311"] +"example/**/*" = ["T201"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = false + + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + + +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + +[[tool.mypy.overrides]] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "azure.identity.*", "mypy.*", "azure.core.*", "azure.search.documents.*"] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py new file mode 100644 index 000000000..eb75ffa6c --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -0,0 +1,3 @@ +from .embedding_retriever import AzureAISearchEmbeddingRetriever + +__all__ = ["AzureAISearchEmbeddingRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py new file mode 100644 index 000000000..ab649f874 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -0,0 +1,116 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchEmbeddingRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a vector similarity metric. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchEmbeddingRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. 
Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query_embedding: floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. 
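+
+        A typical call, sketched with a hypothetical `text_embedder` component: pass the output of
+        the same embedding model that was used to index the documents, e.g.
+        `run(query_embedding=text_embedder.run(text=query)["embedding"])`.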
+ """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py new file mode 100644 index 000000000..635878a38 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore +from .filters import normalize_filters + +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py new file mode 100644 index 000000000..0b59b6e37 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -0,0 +1,440 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +from dataclasses import asdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.search.documents import SearchClient +from azure.search.documents.indexes import SearchIndexClient +from azure.search.documents.indexes.models import ( + HnswAlgorithmConfiguration, + HnswParameters, + SearchableField, + SearchField, + SearchFieldDataType, + SearchIndex, + SimpleField, + VectorSearch, + VectorSearchAlgorithmMetric, + VectorSearchProfile, +) +from azure.search.documents.models import VectorizedQuery +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret, deserialize_secrets_inplace + +from .errors import AzureAISearchDocumentStoreConfigError +from .filters import normalize_filters + +type_mapping = { + str: "Edm.String", + bool: "Edm.Boolean", + int: "Edm.Int32", + float: "Edm.Double", + datetime: "Edm.DateTimeOffset", +} + +DEFAULT_VECTOR_SEARCH = VectorSearch( + profiles=[ + VectorSearchProfile(name="default-vector-config", algorithm_configuration_name="cosine-algorithm-config") + ], + algorithms=[ + HnswAlgorithmConfiguration( + name="cosine-algorithm-config", + parameters=HnswParameters( + metric=VectorSearchAlgorithmMetric.COSINE, + ), + ) + ], +) + +logger = logging.getLogger(__name__) +logging.getLogger("azure").setLevel(logging.ERROR) +logging.getLogger("azure.identity").setLevel(logging.DEBUG) + + +class AzureAISearchDocumentStore: + def __init__( + self, + *, + api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), # 
noqa: B008 + azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=True), # noqa: B008 + index_name: str = "default", + embedding_dimension: int = 768, + metadata_fields: Optional[Dict[str, type]] = None, + vector_search_configuration: VectorSearch = None, + **kwargs, + ): + """ + A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) + as the backend. + + :param azure_endpoint: The URL endpoint of an Azure AI Search service. + :param api_key: The API key to use for authentication. + :param index_name: Name of index in Azure AI Search, if it doesn't exist it will be created. + :param embedding_dimension: Dimension of the embeddings. + :param metadata_fields: A dictionary of metadata keys and their types to create + additional fields in index schema. As fields in Azure SearchIndex cannot be dynamic, + it is necessary to specify the metadata fields in advance. + (e.g. metadata_fields = {"author": str, "date": datetime}) + :param vector_search_configuration: Configuration option related to vector search. + Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. + + :param kwargs: Optional keyword parameters for Azure AI Search. + Some of the supported parameters: + - `api_version`: The Search API version to use for requests. + - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD). + The audience is not considered when using a shared key. If audience is not provided, + the public cloud audience will be assumed. + + For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/) + """ + + azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None + if not azure_endpoint: + msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." + raise ValueError(msg) + + api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None + + self._client = None + self._index_client = None + self._index_fields = [] # type: List[Any] # stores all fields in the final schema of index + self._api_key = api_key + self._azure_endpoint = azure_endpoint + self._index_name = index_name + self._embedding_dimension = embedding_dimension + self._dummy_vector = [-10.0] * self._embedding_dimension + self._metadata_fields = metadata_fields + self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH + self._kwargs = kwargs + + @property + def client(self) -> SearchClient: + + # resolve secrets for authentication + resolved_endpoint = ( + self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint + ) + resolved_key = self._api_key.resolve_value() if isinstance(self._api_key, Secret) else self._api_key + + credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential() + try: + if not self._index_client: + self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs) + if not self._index_exists(self._index_name): + # Create a new index if it does not exist + logger.debug( + "The index '%s' does not exist. 
A new index will be created.", + self._index_name, + ) + self._create_index(self._index_name) + except (HttpResponseError, ClientAuthenticationError) as error: + msg = f"Failed to authenticate with Azure Search: {error}" + raise AzureAISearchDocumentStoreConfigError(msg) from error + + if self._index_client: + # Get the search client, if index client is initialized + index_fields = self._index_client.get_index(self._index_name).fields + self._index_fields = [field.name for field in index_fields] + self._client = self._index_client.get_search_client(self._index_name) + else: + msg = "Search Index Client is not initialized." + raise AzureAISearchDocumentStoreConfigError(msg) + + return self._client + + def _create_index(self, index_name: str, **kwargs) -> None: + """ + Creates a new search index. + :param index_name: Name of the index to create. If None, the index name from the constructor is used. + :param kwargs: Optional keyword parameters. + """ + + # default fields to create index based on Haystack Document (id, content, embedding) + default_fields = [ + SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True), + SearchableField(name="content", type=SearchFieldDataType.String), + SearchField( + name="embedding", + type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + searchable=True, + hidden=False, + vector_search_dimensions=self._embedding_dimension, + vector_search_profile_name="default-vector-config", + ), + ] + + if not index_name: + index_name = self._index_name + if self._metadata_fields: + default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) + index = SearchIndex( + name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs + ) + if self._index_client: + self._index_client.create_index(index) + + def to_dict(self) -> Dict[str, Any]: + # This is not the best solution to serialise this class but is the fastest to implement. + # Not all kwargs types can be serialised to text so this can fail. We must serialise each + # type explicitly to handle this properly. + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None, + api_key=self._api_key.to_dict() if self._api_key is not None else None, + index_name=self._index_name, + embedding_dimension=self._embedding_dimension, + metadata_fields=self._metadata_fields, + vector_search_configuration=self._vector_search_configuration.as_dict(), + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"]) + if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None: + data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration) + return default_from_dict(cls, data) + + def count_documents(self) -> int: + """ + Returns how many documents are present in the search index. + + :returns: list of retrieved documents. 
+ """ + return self.client.get_document_count() + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + """ + Writes the provided documents to search index. + + :param documents: documents to write to the index. + :return: the number of documents added to index. + """ + + def _convert_input_document(documents: Document): + document_dict = asdict(documents) + if not isinstance(document_dict["id"], str): + msg = f"Document id {document_dict['id']} is not a string, " + raise Exception(msg) + index_document = self._convert_haystack_documents_to_azure(document_dict) + + return index_document + + if len(documents) > 0: + if not isinstance(documents[0], Document): + msg = "param 'documents' must contain a list of objects of type Document" + raise ValueError(msg) + + if policy not in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: + logger.warning( + f"AzureAISearchDocumentStore only supports `DuplicatePolicy.OVERWRITE`" + f"but got {policy}. Overwriting duplicates is enabled by default." + ) + client = self.client + documents_to_write = [(_convert_input_document(doc)) for doc in documents] + + if documents_to_write != []: + client.upload_documents(documents_to_write) + return len(documents_to_write) + + def delete_documents(self, document_ids: List[str]) -> None: + """ + Deletes all documents with a matching document_ids from the search index. + + :param document_ids: ids of the documents to be deleted. + """ + if self.count_documents() == 0: + return + documents = self._get_raw_documents_by_id(document_ids) + if documents: + self.client.delete_documents(documents) + + def get_documents_by_id(self, document_ids: List[str]) -> List[Document]: + return self._convert_search_result_to_documents(self._get_raw_documents_by_id(document_ids)) + + def search_documents(self, search_text: str = "*", top_k: int = 10) -> List[Document]: + """ + Returns all documents that match the provided search_text. + If search_text is None, returns all documents. + :param search_text: the text to search for in the Document list. + :param top_k: Maximum number of documents to return. + :returns: A list of Documents that match the given search_text. + """ + result = self.client.search(search_text=search_text, top=top_k) + return self._convert_search_result_to_documents(list(result)) + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the provided filters. + Filters should be given as a dictionary supporting filtering by metadata. For details on + filters, see the [metadata filtering documentation](https://docs.haystack.deepset.ai/docs/metadata-filtering). + + :param filters: the filters to apply to the document list. + :returns: A list of Documents that match the given filters. + """ + if filters: + normalized_filters = normalize_filters(filters) + result = self.client.search(filter=normalized_filters) + return self._convert_search_result_to_documents(result) + else: + return self.search_documents() + + def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]]) -> List[Document]: + """ + Converts Azure search results to Haystack Documents. 
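+        Fields other than `id`, `content`, and `embedding` that exist in the index schema are
+        collected into `meta`, and an embedding equal to the internal dummy vector is returned
+        as `None`.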
+ """ + documents = [] + + for azure_doc in azure_docs: + embedding = azure_doc.get("embedding") + if embedding == self._dummy_vector: + embedding = None + + # Anything besides default fields (id, content, and embedding) is considered metadata + meta = { + key: value + for key, value in azure_doc.items() + if key not in ["id", "content", "embedding"] and key in self._index_fields and value is not None + } + + # Create the document with meta only if it's non-empty + doc = Document( + id=azure_doc["id"], content=azure_doc["content"], embedding=embedding, meta=meta if meta else {} + ) + + documents.append(doc) + return documents + + def _index_exists(self, index_name: Optional[str]) -> bool: + """ + Check if the index exists in the Azure AI Search service. + + :param index_name: The name of the index to check. + :returns bool: whether the index exists. + """ + + if self._index_client and index_name: + return index_name in self._index_client.list_index_names() + else: + msg = "Index name is required to check if the index exists." + raise ValueError(msg) + + def _get_raw_documents_by_id(self, document_ids: List[str]): + """ + Retrieves all Azure documents with a matching document_ids from the document store. + + :param document_ids: ids of the documents to be retrieved. + :returns: list of retrieved Azure documents. + """ + azure_documents = [] + for doc_id in document_ids: + try: + document = self.client.get_document(doc_id) + azure_documents.append(document) + except ResourceNotFoundError: + logger.warning(f"Document with ID {doc_id} not found.") + return azure_documents + + def _convert_haystack_documents_to_azure(self, document: Dict[str, Any]) -> Dict[str, Any]: + """Map the document keys to fields of search index""" + + # Because Azure Search does not allow dynamic fields, we only include fields that are part of the schema + index_document = {k: v for k, v in {**document, **document.get("meta", {})}.items() if k in self._index_fields} + if index_document["embedding"] is None: + index_document["embedding"] = self._dummy_vector + + return index_document + + def _create_metadata_index_fields(self, metadata: Dict[str, Any]) -> List[SimpleField]: + """Create a list of index fields for storing metadata values.""" + + index_fields = [] + metadata_field_mapping = self._map_metadata_field_types(metadata) + + for key, field_type in metadata_field_mapping.items(): + index_fields.append(SimpleField(name=key, type=field_type, filterable=True)) + + return index_fields + + def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]: + """Map metadata field types to Azure Search field types.""" + + metadata_field_mapping = {} + + for key, value_type in metadata.items(): + + if not key[0].isalpha(): + msg = ( + f"Azure Search index only allows field names starting with letters. " + f"Invalid key: {key} will be dropped." + ) + logger.warning(msg) + continue + + field_type = type_mapping.get(value_type) + if not field_type: + error_message = f"Unsupported field type for key '{key}': {value_type}" + raise ValueError(error_message) + metadata_field_mapping[key] = field_type + + return metadata_field_mapping + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + It uses the vector configuration of the document store. 
By default it uses the HNSW algorithm + with cosine similarity. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return, defaults to 10 + + :raises ValueError: If `query_embedding` is an empty list + :returns: List of Document that are most similar to `query_embedding` + """ + + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py new file mode 100644 index 000000000..0fbc80696 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py @@ -0,0 +1,20 @@ +from haystack.document_stores.errors import DocumentStoreError +from haystack.errors import FilterError + + +class AzureAISearchDocumentStoreError(DocumentStoreError): + """Parent class for all AzureAISearchDocumentStore exceptions.""" + + pass + + +class AzureAISearchDocumentStoreConfigError(AzureAISearchDocumentStoreError): + """Raised when a configuration is not valid for a AzureAISearchDocumentStore.""" + + pass + + +class AzureAISearchDocumentStoreFilterError(FilterError): + """Raised when filter is not valid for AzureAISearchDocumentStore.""" + + pass diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py new file mode 100644 index 000000000..650e3f8be --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py @@ -0,0 +1,112 @@ +from typing import Any, Dict + +from dateutil import parser + +from .errors import AzureAISearchDocumentStoreFilterError + +LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"} + + +def normalize_filters(filters: Dict[str, Any]) -> str: + """ + Converts Haystack filters in Azure AI Search compatible filters. + """ + if not isinstance(filters, dict): + msg = """Filters must be a dictionary. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + if "field" in filters: + return _parse_comparison_condition(filters) + return _parse_logical_condition(filters) + + +def _parse_logical_condition(condition: Dict[str, Any]) -> str: + missing_keys = [key for key in ("operator", "conditions") if key not in condition] + if missing_keys: + msg = f"""Missing key(s) {missing_keys} in {condition}. 
+ See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + operator = condition["operator"] + if operator not in LOGICAL_OPERATORS: + msg = f"Unknown operator {operator}" + raise AzureAISearchDocumentStoreFilterError(msg) + conditions = [] + for c in condition["conditions"]: + # Recursively parse if the condition itself is a logical condition + if isinstance(c, dict) and "operator" in c and c["operator"] in LOGICAL_OPERATORS: + conditions.append(_parse_logical_condition(c)) + else: + # Otherwise, parse it as a comparison condition + conditions.append(_parse_comparison_condition(c)) + + # Format the result based on the operator + if operator == "NOT": + return f"not ({' and '.join([f'({c})' for c in conditions])})" + else: + return f" {LOGICAL_OPERATORS[operator]} ".join([f"({c})" for c in conditions]) + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> str: + missing_keys = [key for key in ("field", "operator", "value") if key not in condition] + if missing_keys: + msg = f"""Missing key(s) {missing_keys} in {condition}. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + # Remove the "meta." prefix from the field name if present + field = condition["field"][5:] if condition["field"].startswith("meta.") else condition["field"] + operator = condition["operator"] + value = "null" if condition["value"] is None else condition["value"] + + if operator not in COMPARISON_OPERATORS: + msg = f"Unknown operator {operator}. Valid operators are: {list(COMPARISON_OPERATORS.keys())}" + raise AzureAISearchDocumentStoreFilterError(msg) + + return COMPARISON_OPERATORS[operator](field, value) + + +def _eq(field: str, value: Any) -> str: + return f"{field} eq '{value}'" if isinstance(value, str) and value != "null" else f"{field} eq {value}" + + +def _ne(field: str, value: Any) -> str: + return f"not ({field} eq '{value}')" if isinstance(value, str) and value != "null" else f"not ({field} eq {value})" + + +def _in(field: str, value: Any) -> str: + if not isinstance(value, list) or any(not isinstance(v, str) for v in value): + msg = "Azure AI Search only supports a list of strings for 'in' comparators" + raise AzureAISearchDocumentStoreFilterError(msg) + values = ", ".join(map(str, value)) + return f"search.in({field},'{values}')" + + +def _comparison_operator(field: str, value: Any, operator: str) -> str: + _validate_type(value, operator) + return f"{field} {operator} {value}" + + +def _validate_type(value: Any, operator: str) -> None: + """Validates that the value is either an integer, float, or ISO 8601 string.""" + msg = f"Invalid value type for '{operator}' comparator. Supported types are: int, float, or ISO 8601 string." 
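+    # Illustrative examples of what this check accepts and rejects (based on the branches
+    # below; shown as comments only):
+    #   _validate_type(5, "gt")                      -> passes (int)
+    #   _validate_type(4.5, "lt")                    -> passes (float)
+    #   _validate_type("1969-07-21T20:17:40Z", "ge") -> passes (ISO 8601 string)
+    #   _validate_type("not a date", "gt")           -> raises AzureAISearchDocumentStoreFilterError
+    #   _validate_type(None, "gt")                   -> raises AzureAISearchDocumentStoreFilterError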
+ + if isinstance(value, str): + try: + parser.isoparse(value) + except ValueError as e: + raise AzureAISearchDocumentStoreFilterError(msg) from e + elif not isinstance(value, (int, float)): + raise AzureAISearchDocumentStoreFilterError(msg) + + +COMPARISON_OPERATORS = { + "==": _eq, + "!=": _ne, + "in": _in, + ">": lambda f, v: _comparison_operator(f, v, "gt"), + ">=": lambda f, v: _comparison_operator(f, v, "ge"), + "<": lambda f, v: _comparison_operator(f, v, "lt"), + "<=": lambda f, v: _comparison_operator(f, v, "le"), +} diff --git a/integrations/azure_ai_search/tests/__init__.py b/integrations/azure_ai_search/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/azure_ai_search/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py new file mode 100644 index 000000000..3017c79c2 --- /dev/null +++ b/integrations/azure_ai_search/tests/conftest.py @@ -0,0 +1,68 @@ +import os +import time +import uuid + +import pytest +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ResourceNotFoundError +from azure.search.documents.indexes import SearchIndexClient +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +# This is the approximate time in seconds it takes for the documents to be available in Azure Search index +SLEEP_TIME_IN_SECONDS = 5 + + +@pytest.fixture() +def sleep_time(): + return SLEEP_TIME_IN_SECONDS + + +@pytest.fixture +def document_store(request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. 
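+    The fixture creates an index with a unique name, wraps write_documents/delete_documents so
+    tests wait for documents to become searchable, and deletes the index again on teardown.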
+ """ + index_name = f"haystack_test_{uuid.uuid4().hex}" + metadata_fields = getattr(request, "param", {}).get("metadata_fields", None) + + azure_endpoint = os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"] + api_key = os.environ["AZURE_SEARCH_API_KEY"] + + client = SearchIndexClient(azure_endpoint, AzureKeyCredential(api_key)) + if index_name in client.list_index_names(): + client.delete_index(index_name) + + store = AzureAISearchDocumentStore( + api_key=api_key, + azure_endpoint=azure_endpoint, + index_name=index_name, + create_index=True, + embedding_dimension=768, + metadata_fields=metadata_fields, + ) + + # Override some methods to wait for the documents to be available + original_write_documents = store.write_documents + + def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): + written_docs = original_write_documents(documents, policy) + time.sleep(SLEEP_TIME_IN_SECONDS) + return written_docs + + original_delete_documents = store.delete_documents + + def delete_documents_and_wait(filters): + original_delete_documents(filters) + time.sleep(SLEEP_TIME_IN_SECONDS) + + store.write_documents = write_documents_and_wait + store.delete_documents = delete_documents_and_wait + + yield store + try: + client.delete_index(index_name) + except ResourceNotFoundError: + pass diff --git a/integrations/azure_ai_search/tests/test_document_store.py b/integrations/azure_ai_search/tests/test_document_store.py new file mode 100644 index 000000000..1bcd967c6 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_document_store.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +import random +from datetime import datetime, timezone +from typing import List +from unittest.mock import patch + +import pytest +from haystack.dataclasses.document import Document +from haystack.errors import FilterError +from haystack.testing.document_store import ( + CountDocumentsTest, + DeleteDocumentsTest, + FilterDocumentsTest, + WriteDocumentsTest, +) +from haystack.utils.auth import EnvVarSecret, Secret + +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_to_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + document_store = AzureAISearchDocumentStore() + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + }, + } + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_from_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + 
monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + + data = { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "embedding_dimension": 768, + "index_name": "default", + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + }, + } + document_store = AzureAISearchDocumentStore.from_dict(data) + assert isinstance(document_store._api_key, EnvVarSecret) + assert isinstance(document_store._azure_endpoint, EnvVarSecret) + assert document_store._index_name == "default" + assert document_store._embedding_dimension == 768 + assert document_store._metadata_fields is None + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init_is_lazy(_mock_azure_search_client): + AzureAISearchDocumentStore(azure_endpoint=Secret.from_token("test_endpoint")) + _mock_azure_search_client.assert_not_called() + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init(_mock_azure_search_client): + + document_store = AzureAISearchDocumentStore( + api_key=Secret.from_token("fake-api-key"), + azure_endpoint=Secret.from_token("fake_endpoint"), + index_name="my_index", + embedding_dimension=15, + metadata_fields={"Title": str, "Pages": int}, + ) + + assert document_store._index_name == "my_index" + assert document_store._embedding_dimension == 15 + assert document_store._metadata_fields == {"Title": str, "Pages": int} + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): + + def test_write_documents(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + assert document_store.write_documents(docs) == 1 + + # Parametrize the test with metadata fields + @pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"author": str, "publication_year": int, "rating": float}}, + ], + indirect=True, + ) + def test_write_documents_with_meta(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document( + id="1", + meta={"author": "Tom", "publication_year": 2021, "rating": 4.5}, + content="This is a test document.", + ) + ] + document_store.write_documents(docs) + doc = document_store.get_documents_by_id(["1"]) + assert doc[0] == docs[0] + + @pytest.mark.skip(reason="Azure AI search index overwrites duplicate documents by default") + def test_write_documents_duplicate_fail(self, document_store: AzureAISearchDocumentStore): ... + + @pytest.mark.skip(reason="Azure AI search index overwrites duplicate documents by default") + def test_write_documents_duplicate_skip(self, document_store: AzureAISearchDocumentStore): ... 
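The parametrized test above declares `author`, `publication_year`, and `rating` as extra metadata fields on the store. As a rough sketch of how such declarations translate into Azure index field types, following the `type_mapping` table this patch adds in `document_store.py` (the field names here are only illustrative, and this is a standalone snippet rather than the store's actual code path):

```python
from datetime import datetime

# Same mapping as in document_store.py: Python types -> Azure Search (EDM) field types.
type_mapping = {
    str: "Edm.String",
    bool: "Edm.Boolean",
    int: "Edm.Int32",
    float: "Edm.Double",
    datetime: "Edm.DateTimeOffset",
}

metadata_fields = {"author": str, "publication_year": int, "rating": float}
print({name: type_mapping[py_type] for name, py_type in metadata_fields.items()})
# {'author': 'Edm.String', 'publication_year': 'Edm.Int32', 'rating': 'Edm.Double'}
```

Each declared field becomes a filterable `SimpleField` in the index schema, which is why the filter tests further below can query on metadata keys.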
+ + +def _random_embeddings(n): + return [round(random.random(), 7) for _ in range(n)] # nosec: S311 + + +TEST_EMBEDDING_1 = _random_embeddings(768) +TEST_EMBEDDING_2 = _random_embeddings(768) + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"name": str, "page": str, "chapter": str, "number": int, "date": datetime}}, + ], + indirect=True, +) +class TestFilters(FilterDocumentsTest): + + # Overriding to change "date" to compatible ISO 8601 format + # and remove incompatible fields (dataframes) for Azure search index + @pytest.fixture + def filterable_docs(self) -> List[Document]: + """Fixture that returns a list of Documents that can be used to test filtering.""" + documents = [] + for i in range(3): + documents.append( + Document( + content=f"A Foo Document {i}", + meta={ + "name": f"name_{i}", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40Z", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Bar Document {i}", + meta={ + "name": f"name_{i}", + "page": "123", + "chapter": "abstract", + "number": -2, + "date": "1972-12-11T19:54:58Z", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Foobar Document {i}", + meta={ + "name": f"name_{i}", + "page": "90", + "chapter": "conclusion", + "number": -10, + "date": "1989-11-09T17:53:00Z", + }, + embedding=_random_embeddings(768), + ) + ) + + documents.append( + Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) + ) + documents.append( + Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) + ) + return documents + + # Overriding to compare the documents with the same order + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + Assert that two lists of Documents are equal. + + This is used in every test, if a Document Store implementation has a different behaviour + it should override this method. This can happen for example when the Document Store sets + a score to returned Documents. Since we can't know what the score will be, we can't compare + the Documents reliably. + """ + sorted_recieved = sorted(received, key=lambda doc: doc.id) + sorted_expected = sorted(expected, key=lambda doc: doc.id) + assert sorted_recieved == sorted_expected + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): ... 
+ + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): ... + + # Azure search index supports UTC datetime in ISO 8601 format + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with > comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">", "value": "1972-12-11T19:54:58Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + > datetime.strptime("1972-12-11T19:54:58Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with >= comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">=", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + >= datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with < comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + < datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with <= comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<=", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + <= datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + # Override as comparison operators with None/null raise errors + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): + """Test filter_documents() with > comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": ">", "value": None}) + + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): + """Test filter_documents() with >= comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": ">=", "value": None}) + + 
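The ISO 8601 and `None` comparison tests above go through the filter normalization added in `filters.py`. A minimal sketch of that behaviour, assuming the package from this patch is installed; the commented filter string reflects my reading of the code, not captured output:

```python
from haystack_integrations.document_stores.azure_ai_search import normalize_filters
from haystack_integrations.document_stores.azure_ai_search.errors import (
    AzureAISearchDocumentStoreFilterError,
)

# An ISO 8601 string passes validation and becomes an OData range expression,
# roughly: date gt 1972-12-11T19:54:58Z
print(normalize_filters({"field": "meta.date", "operator": ">", "value": "1972-12-11T19:54:58Z"}))

# A None value is rendered as "null", which fails the ISO 8601 check used for range
# operators, so a FilterError subclass is raised, matching the expectations above.
try:
    normalize_filters({"field": "meta.number", "operator": ">", "value": None})
except AzureAISearchDocumentStoreFilterError as err:
    print(f"rejected: {err}")
```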
def test_comparison_less_than_with_none(self, document_store, filterable_docs): + """Test filter_documents() with < comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": "<", "value": None}) + + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): + """Test filter_documents() with <= comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": "<=", "value": None}) + + # Override as Azure AI Search supports 'in' operator only for strings + def test_comparison_in(self, document_store, filterable_docs): + """Test filter_documents() with 'in' comparator""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents({"field": "meta.page", "operator": "in", "value": ["100", "123"]}) + assert len(result) + expected = [d for d in filterable_docs if d.meta.get("page") is not None and d.meta["page"] in ["100", "123"]] + self.assert_documents_are_equal(result, expected) + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs): ... + + def test_missing_condition_operator_key(self, document_store, filterable_docs): + """Test filter_documents() with missing operator key""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents( + filters={"conditions": [{"field": "meta.name", "operator": "eq", "value": "test"}]} + ) + + def test_nested_logical_filters(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + {"field": "meta.name", "operator": "==", "value": "name_0"}, + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "!=", "value": 0}, + {"field": "meta.page", "operator": "==", "value": "123"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + {"field": "meta.page", "operator": "==", "value": "90"}, + ], + }, + ], + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + # Ensure all required fields are present in doc.meta + ("name" in doc.meta and doc.meta.get("name") == "name_0") + or ( + all(key in doc.meta for key in ["number", "page"]) + and doc.meta.get("number") != 0 + and doc.meta.get("page") == "123" + ) + or ( + all(key in doc.meta for key in ["page", "chapter"]) + and doc.meta.get("chapter") == "conclusion" + and doc.meta.get("page") == "90" + ) + ) + ], + ) diff --git a/integrations/azure_ai_search/tests/test_embedding_retriever.py b/integrations/azure_ai_search/tests/test_embedding_retriever.py new file mode 100644 index 000000000..d4615ec44 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_embedding_retriever.py @@ -0,0 +1,145 @@ +# 
SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchEmbeddingRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever.AzureAISearchEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever.AzureAISearchEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchEmbeddingRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == 
FilterPolicy.REPLACE + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + res = retriever.run(query_embedding=[0.1] * 768) + assert res["documents"] == docs + + def test_embedding_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is thrid document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + results = retriever.run(query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._embedding_retrieval(query_embedding=query_embedding) From 34ae0bdd3e4486c9d515237e11f1d5966679a167 Mon Sep 17 00:00:00 2001 From: Bohan Qu Date: Sun, 10 Nov 2024 07:17:48 +0800 Subject: [PATCH 062/229] feat: add support for ttft (#1161) * feat: add support for ttft * chore: skip ttft logging if completion start time is invalid * chore: addressing lint issues --------- Co-authored-by: Vladimir Blagojevic --- .../tracing/langfuse/tracer.py | 17 ++- integrations/langfuse/tests/test_tracer.py | 101 ++++++++++++------ 2 files changed, 85 insertions(+), 33 deletions(-) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index 94064a0d1..c9c8a354e 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -1,5 +1,7 @@ import contextlib +import logging import os +from datetime import datetime from typing import Any, Dict, Iterator, Optional, Union from haystack.components.generators.openai_utils import _convert_message_to_openai_format @@ -9,6 +11,8 @@ import langfuse +logger = logging.getLogger(__name__) + HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH" _SUPPORTED_GENERATORS = [ "AzureOpenAIGenerator", @@ -148,7 +152,18 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I replies = span._data.get("haystack.component.output", {}).get("replies") if replies: meta = replies[0].meta - span._span.update(usage=meta.get("usage") or None, model=meta.get("model")) + completion_start_time = meta.get("completion_start_time") + if completion_start_time: + try: + completion_start_time = 
datetime.fromisoformat(completion_start_time) + except ValueError: + logger.error(f"Failed to parse completion_start_time: {completion_start_time}") + completion_start_time = None + span._span.update( + usage=meta.get("usage") or None, + model=meta.get("model"), + completion_start_time=completion_start_time, + ) pipeline_input = tags.get("haystack.pipeline.input_data", None) if pipeline_input: diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index c6bf4acdf..9ee8e5dc4 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -1,9 +1,43 @@ -import os +import datetime from unittest.mock import MagicMock, Mock, patch +from haystack.dataclasses import ChatMessage from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer +class MockSpan: + def __init__(self): + self._data = {} + self._span = self + self.operation_name = "operation_name" + + def raw_span(self): + return self + + def span(self, name=None): + # assert correct operation name passed to the span + assert name == "operation_name" + return self + + def update(self, **kwargs): + self._data.update(kwargs) + + def generation(self, name=None): + return self + + def end(self): + pass + + +class MockTracer: + + def trace(self, name, **kwargs): + return MockSpan() + + def flush(self): + pass + + class TestLangfuseTracer: # LangfuseTracer can be initialized with a Langfuse instance, a name and a boolean value for public. @@ -45,37 +79,6 @@ def test_create_new_span(self): # check that update method is called on the span instance with the provided key value pairs def test_update_span_with_pipeline_input_output_data(self): - class MockTracer: - - def trace(self, name, **kwargs): - return MockSpan() - - def flush(self): - pass - - class MockSpan: - def __init__(self): - self._data = {} - self._span = self - self.operation_name = "operation_name" - - def raw_span(self): - return self - - def span(self, name=None): - # assert correct operation name passed to the span - assert name == "operation_name" - return self - - def update(self, **kwargs): - self._data.update(kwargs) - - def generation(self, name=None): - return self - - def end(self): - pass - tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span: assert span.raw_span()._data["metadata"] == {"haystack.pipeline.input_data": "hello"} @@ -83,6 +86,40 @@ def end(self): with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.output_data": "bye"}) as span: assert span.raw_span()._data["metadata"] == {"haystack.pipeline.output_data": "bye"} + def test_trace_generation(self): + tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) + tags = { + "haystack.component.type": "OpenAIChatGenerator", + "haystack.component.output": { + "replies": [ + ChatMessage.from_assistant( + "", meta={"completion_start_time": "2021-07-27T16:02:08.012345", "model": "test_model"} + ) + ] + }, + } + with tracer.trace(operation_name="operation_name", tags=tags) as span: + ... 
+ assert span.raw_span()._data["usage"] is None + assert span.raw_span()._data["model"] == "test_model" + assert span.raw_span()._data["completion_start_time"] == datetime.datetime(2021, 7, 27, 16, 2, 8, 12345) + + def test_trace_generation_invalid_start_time(self): + tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) + tags = { + "haystack.component.type": "OpenAIChatGenerator", + "haystack.component.output": { + "replies": [ + ChatMessage.from_assistant("", meta={"completion_start_time": "foobar", "model": "test_model"}), + ] + }, + } + with tracer.trace(operation_name="operation_name", tags=tags) as span: + ... + assert span.raw_span()._data["usage"] is None + assert span.raw_span()._data["model"] == "test_model" + assert span.raw_span()._data["completion_start_time"] is None + def test_update_span_gets_flushed_by_default(self): tracer_mock = Mock() From 1bcb9a8e6a3d1ab5955dcb6fb18ec90c184ea258 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 11 Nov 2024 16:40:01 +0100 Subject: [PATCH 063/229] Weaviate - skip writing _split_overlap meta field (#1173) --- .../weaviate/document_store.py | 8 +++++++ .../weaviate/tests/test_document_store.py | 24 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index e312b1473..6acf0156e 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -286,6 +286,14 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]: # The embedding vector is stored separately from the rest of the data del data["embedding"] + # _split_overlap meta field is unsupported because of a bug + # https://github.com/deepset-ai/haystack-core-integrations/issues/1172 + if "_split_overlap" in data: + data.pop("_split_overlap") + logger.warning( + "Document %s has the unsupported `_split_overlap` meta field. 
It will be ignored.", data["_original_id"] + ) + if "sparse_embedding" in data: sparse_embedding = data.pop("sparse_embedding", None) if sparse_embedding: diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 70f1e1eb2..00af322e4 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -508,6 +508,30 @@ def test_comparison_less_than_equal_with_iso_date(self, document_store, filterab def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): return super().test_comparison_not_equal_with_dataframe(document_store, filterable_docs) + def test_meta_split_overlap_is_skipped(self, document_store): + doc = Document( + content="The moonlight shimmered ", + meta={ + "source_id": "62049ba1d1e1d5ebb1f6230b0b00c5356b8706c56e0b9c36b1dfc86084cd75f0", + "page_number": 1, + "split_id": 0, + "split_idx_start": 0, + "_split_overlap": [ + {"doc_id": "68ed48ba830048c5d7815874ed2de794722e6d10866b6c55349a914fd9a0df65", "range": (0, 20)} + ], + }, + ) + document_store.write_documents([doc]) + + written_doc = document_store.filter_documents()[0] + + assert written_doc.content == "The moonlight shimmered " + assert written_doc.meta["source_id"] == "62049ba1d1e1d5ebb1f6230b0b00c5356b8706c56e0b9c36b1dfc86084cd75f0" + assert written_doc.meta["page_number"] == 1.0 + assert written_doc.meta["split_id"] == 0.0 + assert written_doc.meta["split_idx_start"] == 0.0 + assert "_split_overlap" not in written_doc.meta + def test_bm25_retrieval(self, document_store): document_store.write_documents( [ From 4cfee2de5e77691ca7c7b9754dc06d696fdb9185 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 11 Nov 2024 15:41:46 +0000 Subject: [PATCH 064/229] Update the changelog --- integrations/weaviate/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integrations/weaviate/CHANGELOG.md b/integrations/weaviate/CHANGELOG.md index ec15cbeef..6ffe0e60b 100644 --- a/integrations/weaviate/CHANGELOG.md +++ b/integrations/weaviate/CHANGELOG.md @@ -1,5 +1,7 @@ # Changelog +## [integrations/weaviate-v4.0.1] - 2024-11-11 + ## [integrations/weaviate-v4.0.0] - 2024-10-18 ### 🐛 Bug Fixes From 946e1540386d9d03ca670e660f63c6572e1d1b5b Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 12 Nov 2024 12:55:01 +0100 Subject: [PATCH 065/229] fix: `GoogleAIGeminiGenerator` - remove support for tools and change output type (#1177) * GoogleAIGeminiGenerator - rm support for tools * simplify --- .../components/generators/google_ai/gemini.py | 32 +++---- .../google_ai/tests/generators/test_gemini.py | 85 ++----------------- 2 files changed, 19 insertions(+), 98 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index 218e16c4c..b032169df 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai -from google.ai.generativelanguage import Content, Part, Tool +from google.ai.generativelanguage import Content, Part from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory from 
haystack.core.component import component @@ -62,6 +62,16 @@ class GoogleAIGeminiGenerator: ``` """ + def __new__(cls, *_, **kwargs): + if "tools" in kwargs: + msg = ( + "GoogleAIGeminiGenerator does not support the `tools` parameter. " + " Use GoogleAIGeminiChatGenerator instead." + ) + raise TypeError(msg) + return super(GoogleAIGeminiGenerator, cls).__new__(cls) # noqa: UP008 + # super(__class__, cls) is needed because of the component decorator + def __init__( self, *, @@ -69,7 +79,6 @@ def __init__( model: str = "gemini-1.5-flash", generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, - tools: Optional[List[Tool]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -86,7 +95,6 @@ def __init__( :param safety_settings: The safety settings to use. A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. For more information, see [the API reference](https://ai.google.dev/api) - :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. """ @@ -96,8 +104,7 @@ def __init__( self._model_name = model self._generation_config = generation_config self._safety_settings = safety_settings - self._tools = tools - self._model = GenerativeModel(self._model_name, tools=self._tools) + self._model = GenerativeModel(self._model_name) self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: @@ -126,11 +133,8 @@ def to_dict(self) -> Dict[str, Any]: model=self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -149,8 +153,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator": """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -178,7 +180,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -192,7 +194,7 @@ def run( :param streaming_callback: A callback function that is called when a new token is received from the stream. 
:returns: A dictionary containing the following key: - - `replies`: A list of strings or dictionaries with function calls. + - `replies`: A list of strings containing the generated responses. """ # check if streaming_callback is passed @@ -221,12 +223,6 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[str]: for part in candidate.content.parts: if part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index 7206b7a43..07d194a59 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -2,32 +2,12 @@ from unittest.mock import patch import pytest -from google.ai.generativelanguage import FunctionDeclaration, Tool from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import HarmBlockThreshold, HarmCategory from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - def test_init(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -41,40 +21,24 @@ def test_init(monkeypatch): top_k=0.5, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - tool = Tool(function_declarations=[get_current_weather_func]) with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure") as mock_genai_configure: gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], ) mock_genai_configure.assert_called_once_with(api_key="test") assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] assert isinstance(gemini._model, GenerativeModel) +def test_init_fails_with_tools(): + with pytest.raises(TypeError, match="GoogleAIGeminiGenerator does not support the `tools` parameter."): + GoogleAIGeminiGenerator(tools=["tool1", "tool2"]) + + def test_to_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -88,7 +52,6 @@ def test_to_dict(monkeypatch): "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, }, } @@ -105,32 +68,11 @@ def test_to_dict_with_param(monkeypatch): top_k=2, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - - tool = Tool(function_declarations=[get_current_weather_func]) with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], ) assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", @@ -147,11 +89,6 @@ def test_to_dict_with_param(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } @@ -175,11 +112,6 @@ def test_from_dict_with_param(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. 
San Francisco, CAB\x08location" - ], }, } ) @@ -194,7 +126,6 @@ def test_from_dict_with_param(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) @@ -217,11 +148,6 @@ def test_from_dict(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } ) @@ -236,7 +162,6 @@ def test_from_dict(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) From 394f7e11f40cd1d7307fcb6e25a86000a4da914f Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 12 Nov 2024 11:57:01 +0000 Subject: [PATCH 066/229] Update the changelog --- integrations/google_ai/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 3f3ecaf79..8f09db79a 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/google_ai-v3.0.0] - 2024-11-12 + +### 🐛 Bug Fixes + +- `GoogleAIGeminiGenerator` - remove support for tools and change output type (#1177) + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/google_ai-v2.0.1] - 2024-10-15 ### 🚀 Features From 72e25306169a807ea8736c14d866bfca0b6974d7 Mon Sep 17 00:00:00 2001 From: ArzelaAscoIi <37148029+ArzelaAscoIi@users.noreply.github.com> Date: Wed, 13 Nov 2024 09:43:26 +0100 Subject: [PATCH 067/229] fix: dependency for weaviate document store (#1186) * fix: dependency for weaviate document store * fix * fix * well. 
now finally --- integrations/weaviate/pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index 70b045bc4..e88397df9 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -26,7 +26,6 @@ classifiers = [ dependencies = [ "haystack-ai", "weaviate-client>=4.9", - "haystack-pydoc-tools", "python-dateutil", ] @@ -48,7 +47,7 @@ git_describe_command = 'git describe --tags --match="integrations/weaviate-v[0-9 [tool.hatch.envs.default] installer = "uv" -dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "ipython"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "ipython", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" From 051caadeebc59e85913ec69fc6fb95756cd8f2b8 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 13 Nov 2024 08:47:01 +0000 Subject: [PATCH 068/229] Update the changelog --- integrations/weaviate/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/weaviate/CHANGELOG.md b/integrations/weaviate/CHANGELOG.md index 6ffe0e60b..7f620c3a0 100644 --- a/integrations/weaviate/CHANGELOG.md +++ b/integrations/weaviate/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/weaviate-v4.0.2] - 2024-11-13 + +### 🐛 Bug Fixes + +- Dependency for weaviate document store (#1186) + ## [integrations/weaviate-v4.0.1] - 2024-11-11 ## [integrations/weaviate-v4.0.0] - 2024-10-18 From a221ad59852a3f9aac665139733be59d314dcf92 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 13 Nov 2024 05:17:22 -0500 Subject: [PATCH 069/229] add nvidia/llama-3.2-nv-rerankqa-1b-v1 to set of known ranking models (#1183) --- .../haystack_integrations/components/rankers/nvidia/ranker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index 46c736883..1553d1ac3 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -12,6 +12,7 @@ _MODEL_ENDPOINT_MAP = { "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", + "nvidia/llama-3.2-nv-rerankqa-1b-v1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v1/reranking", } From c3437763e7c80e11e9ed8678f1b235353a316781 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 13 Nov 2024 10:18:34 +0000 Subject: [PATCH 070/229] Update the changelog --- integrations/nvidia/CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index f66536fe5..75b31d033 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -1,15 +1,17 @@ # Changelog -## [unreleased] +## [integrations/nvidia-v0.1.0] - 2024-11-13 ### 🚀 Features - Update default embedding model to nvidia/nv-embedqa-e5-v5 (#1015) - Add NVIDIA NIM ranker support (#1023) +- Raise error when attempting to embed empty documents/strings with Nvidia embedders (#1118) ### 🐛 Bug Fixes - Lints in `nvidia-haystack` (#993) +- Missing Nvidia embedding truncate mode (#1043) ### 🚜 Refactor @@ -27,6 +29,8 @@ - Retry tests to reduce 
flakyness (#836) - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) ### Docs From 55a1cbf56d2d9da253e57aea3aadb8b68da9514a Mon Sep 17 00:00:00 2001 From: paulmartrencharpro <148542350+paulmartrencharpro@users.noreply.github.com> Date: Wed, 13 Nov 2024 19:59:15 +0100 Subject: [PATCH 071/229] squashing (#1178) Co-authored-by: anakin87 --- README.md | 2 +- .../fastembed/examples/ranker_example.py | 22 ++ integrations/fastembed/pydoc/config.yml | 3 +- integrations/fastembed/pyproject.toml | 6 +- .../components/rankers/fastembed/__init__.py | 3 + .../components/rankers/fastembed/ranker.py | 202 ++++++++++++ .../fastembed/tests/test_fastembed_ranker.py | 292 ++++++++++++++++++ 7 files changed, 527 insertions(+), 3 deletions(-) create mode 100644 integrations/fastembed/examples/ranker_example.py create mode 100644 integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py create mode 100644 integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py create mode 100644 integrations/fastembed/tests/test_fastembed_ranker.py diff --git a/README.md b/README.md index 2b4a83253..af83d045d 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [cohere-haystack](integrations/cohere/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | | [deepeval-haystack](integrations/deepeval/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/deepeval-haystack.svg)](https://pypi.org/project/deepeval-haystack) | [![Test / deepeval](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml) | | [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | -| [fastembed-haystack](integrations/fastembed/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | +| [fastembed-haystack](integrations/fastembed/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | | [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - 
Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | | [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | | [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | diff --git a/integrations/fastembed/examples/ranker_example.py b/integrations/fastembed/examples/ranker_example.py new file mode 100644 index 000000000..7a31e4646 --- /dev/null +++ b/integrations/fastembed/examples/ranker_example.py @@ -0,0 +1,22 @@ +from haystack import Document + +from haystack_integrations.components.rankers.fastembed import FastembedRanker + +query = "Who is maintaining Qdrant?" +documents = [ + Document( + content="This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc." + ), + Document(content="fastembed is supported by and maintained by Qdrant."), +] + +ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2") +ranker.warm_up() +reranked_documents = ranker.run(query=query, documents=documents)["documents"] + + +print(reranked_documents["documents"][0]) + +# Document(id=..., +# content: 'fastembed is supported by and maintained by Qdrant.', +# score: 5.472434997558594..) 
diff --git a/integrations/fastembed/pydoc/config.yml b/integrations/fastembed/pydoc/config.yml index aad50e52c..8ab538cf8 100644 --- a/integrations/fastembed/pydoc/config.yml +++ b/integrations/fastembed/pydoc/config.yml @@ -6,7 +6,8 @@ loaders: "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder", "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder", "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder", - "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder" + "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder", + "haystack_integrations.components.rankers.fastembed.ranker" ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/fastembed/pyproject.toml b/integrations/fastembed/pyproject.toml index b9f1f6cfd..abae78d8a 100644 --- a/integrations/fastembed/pyproject.toml +++ b/integrations/fastembed/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.2.5", "onnxruntime<1.20.0"] +dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.4.2"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" @@ -154,6 +154,10 @@ omit = ["*/tests/*", "*/__init__.py"] show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + [[tool.mypy.overrides]] module = [ "haystack.*", diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py new file mode 100644 index 000000000..ece5e858b --- /dev/null +++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py @@ -0,0 +1,3 @@ +from .ranker import FastembedRanker + +__all__ = ["FastembedRanker"] diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py new file mode 100644 index 000000000..8f077a30c --- /dev/null +++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py @@ -0,0 +1,202 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict, logging + +from fastembed.rerank.cross_encoder import TextCrossEncoder + +logger = logging.getLogger(__name__) + + +@component +class FastembedRanker: + """ + Ranks Documents based on their similarity to the query using + [Fastembed models](https://qdrant.github.io/fastembed/examples/Supported_Models/). + + Documents are indexed from most to least semantically relevant to the query. + + Usage example: + ```python + from haystack import Document + from haystack_integrations.components.rankers.fastembed import FastembedRanker + + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2", top_k=2) + + docs = [Document(content="Paris"), Document(content="Berlin")] + query = "What is the capital of germany?" 
+ output = ranker.run(query=query, documents=docs) + print(output["documents"][0].content) + + # Berlin + ``` + """ + + def __init__( + self, + model_name: str = "Xenova/ms-marco-MiniLM-L-6-v2", + top_k: int = 10, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + batch_size: int = 64, + parallel: Optional[int] = None, + local_files_only: bool = False, + meta_fields_to_embed: Optional[List[str]] = None, + meta_data_separator: str = "\n", + ): + """ + Creates an instance of the 'FastembedRanker'. + + :param model_name: Fastembed model name. Check the list of supported models in the [Fastembed documentation](https://qdrant.github.io/fastembed/examples/Supported_Models/). + :param top_k: The maximum number of documents to return. + :param cache_dir: The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. + :param threads: The number of threads single onnxruntime session can use. Defaults to None. + :param batch_size: Number of strings to encode at once. + :param parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + :param local_files_only: If `True`, only use the model files in the `cache_dir`. + :param meta_fields_to_embed: List of meta fields that should be concatenated + with the document content for reranking. + :param meta_data_separator: Separator used to concatenate the meta fields + to the Document content. + """ + if top_k <= 0: + msg = f"top_k must be > 0, but got {top_k}" + raise ValueError(msg) + + self.model_name = model_name + self.top_k = top_k + self.cache_dir = cache_dir + self.threads = threads + self.batch_size = batch_size + self.parallel = parallel + self.local_files_only = local_files_only + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.meta_data_separator = meta_data_separator + self._model = None + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model_name=self.model_name, + top_k=self.top_k, + cache_dir=self.cache_dir, + threads=self.threads, + batch_size=self.batch_size, + parallel=self.parallel, + local_files_only=self.local_files_only, + meta_fields_to_embed=self.meta_fields_to_embed, + meta_data_separator=self.meta_data_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FastembedRanker": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + return default_from_dict(cls, data) + + def warm_up(self): + """ + Initializes the component. + """ + if self._model is None: + self._model = TextCrossEncoder( + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, + ) + + def _prepare_fastembed_input_docs(self, documents: List[Document]) -> List[str]: + """ + Prepare the input by concatenating the document text with the metadata fields specified. + :param documents: The list of Document objects. + + :return: A list of strings to be given as input to Fastembed model. 
+ """ + concatenated_input_list = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta.get(key) + ] + concatenated_input = self.meta_data_separator.join([*meta_values_to_embed, doc.content or ""]) + concatenated_input_list.append(concatenated_input) + + return concatenated_input_list + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Returns a list of documents ranked by their similarity to the given query, using FastEmbed. + + :param query: + The input query to compare the documents to. + :param documents: + A list of documents to be ranked. + :param top_k: + The maximum number of documents to return. + + :returns: + A dictionary with the following keys: + - `documents`: A list of documents closest to the query, sorted from most similar to least similar. + + :raises ValueError: If `top_k` is not > 0. + """ + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + msg = "FastembedRanker expects a list of Documents as input. " + raise TypeError(msg) + if query == "": + msg = "No query provided" + raise ValueError(msg) + + if not documents: + return {"documents": []} + + top_k = top_k or self.top_k + if top_k <= 0: + msg = f"top_k must be > 0, but got {top_k}" + raise ValueError(msg) + + if self._model is None: + msg = "The ranker model has not been loaded. Please call warm_up() before running." + raise RuntimeError(msg) + + fastembed_input_docs = self._prepare_fastembed_input_docs(documents) + + scores = list( + self._model.rerank( + query=query, + documents=fastembed_input_docs, + batch_size=self.batch_size, + parallel=self.parallel, + ) + ) + + # Combine the two lists into a single list of tuples + doc_scores = list(zip(documents, scores)) + + # Sort the list of tuples by the score in descending order + sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True) + + # Get the top_k documents + top_k_documents = [] + for doc, score in sorted_doc_scores[:top_k]: + doc.score = score + top_k_documents.append(doc) + + return {"documents": top_k_documents} diff --git a/integrations/fastembed/tests/test_fastembed_ranker.py b/integrations/fastembed/tests/test_fastembed_ranker.py new file mode 100644 index 000000000..e38229c87 --- /dev/null +++ b/integrations/fastembed/tests/test_fastembed_ranker.py @@ -0,0 +1,292 @@ +from unittest.mock import MagicMock + +import pytest +from haystack import Document, default_from_dict + +from haystack_integrations.components.rankers.fastembed.ranker import ( + FastembedRanker, +) + + +class TestFastembedRanker: + def test_init_default(self): + """ + Test default initialization parameters for FastembedRanker. + """ + ranker = FastembedRanker(model_name="BAAI/bge-reranker-base") + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.top_k == 10 + assert ranker.cache_dir is None + assert ranker.threads is None + assert ranker.batch_size == 64 + assert ranker.parallel is None + assert not ranker.local_files_only + assert ranker.meta_fields_to_embed == [] + assert ranker.meta_data_separator == "\n" + + def test_init_with_parameters(self): + """ + Test custom initialization parameters for FastembedRanker. 
+ """ + ranker = FastembedRanker( + model_name="BAAI/bge-reranker-base", + top_k=64, + cache_dir="fake_dir", + threads=2, + batch_size=50, + parallel=1, + local_files_only=True, + meta_fields_to_embed=["test_field"], + meta_data_separator=" | ", + ) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.top_k == 64 + assert ranker.cache_dir == "fake_dir" + assert ranker.threads == 2 + assert ranker.batch_size == 50 + assert ranker.parallel == 1 + assert ranker.local_files_only + assert ranker.meta_fields_to_embed == ["test_field"] + assert ranker.meta_data_separator == " | " + + def test_init_with_incorrect_input(self): + """ + Test for checking incorrect input format on init + """ + with pytest.raises( + ValueError, + match="top_k must be > 0, but got 0", + ): + FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2", top_k=0) + + with pytest.raises( + ValueError, + match="top_k must be > 0, but got -3", + ): + FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2", top_k=-3) + + def test_to_dict(self): + """ + Test serialization of FastembedRanker to a dictionary, using default initialization parameters. + """ + ranker = FastembedRanker(model_name="BAAI/bge-reranker-base") + ranker_dict = ranker.to_dict() + assert ranker_dict == { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "top_k": 10, + "cache_dir": None, + "threads": None, + "batch_size": 64, + "parallel": None, + "local_files_only": False, + "meta_fields_to_embed": [], + "meta_data_separator": "\n", + }, + } + + def test_to_dict_with_custom_init_parameters(self): + """ + Test serialization of FastembedRanker to a dictionary, using custom initialization parameters. + """ + ranker = FastembedRanker( + model_name="BAAI/bge-reranker-base", + cache_dir="fake_dir", + threads=2, + top_k=5, + batch_size=50, + parallel=1, + local_files_only=True, + meta_fields_to_embed=["test_field"], + meta_data_separator=" | ", + ) + ranker_dict = ranker.to_dict() + assert ranker_dict == { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": "fake_dir", + "threads": 2, + "top_k": 5, + "batch_size": 50, + "parallel": 1, + "local_files_only": True, + "meta_fields_to_embed": ["test_field"], + "meta_data_separator": " | ", + }, + } + + def test_from_dict(self): + """ + Test deserialization of FastembedRanker from a dictionary, using default initialization parameters. + """ + ranker_dict = { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": None, + "threads": None, + "top_k": 5, + "batch_size": 50, + "parallel": None, + "local_files_only": False, + "meta_fields_to_embed": [], + "meta_data_separator": "\n", + }, + } + ranker = default_from_dict(FastembedRanker, ranker_dict) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.cache_dir is None + assert ranker.threads is None + assert ranker.top_k == 5 + assert ranker.batch_size == 50 + assert ranker.parallel is None + assert not ranker.local_files_only + assert ranker.meta_fields_to_embed == [] + assert ranker.meta_data_separator == "\n" + + def test_from_dict_with_custom_init_parameters(self): + """ + Test deserialization of FastembedRanker from a dictionary, using custom initialization parameters. 
+ """ + ranker_dict = { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": "fake_dir", + "threads": 2, + "top_k": 5, + "batch_size": 50, + "parallel": 1, + "local_files_only": True, + "meta_fields_to_embed": ["test_field"], + "meta_data_separator": " | ", + }, + } + ranker = default_from_dict(FastembedRanker, ranker_dict) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.cache_dir == "fake_dir" + assert ranker.threads == 2 + assert ranker.top_k == 5 + assert ranker.batch_size == 50 + assert ranker.parallel == 1 + assert ranker.local_files_only + assert ranker.meta_fields_to_embed == ["test_field"] + assert ranker.meta_data_separator == " | " + + def test_run_incorrect_input_format(self): + """ + Test for checking incorrect input format. + """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + ranker._model = "mock_model" + + query = "query" + string_input = "text" + list_integers_input = [1, 2, 3] + list_document = [Document("Document 1")] + + with pytest.raises( + TypeError, + match="FastembedRanker expects a list of Documents as input.", + ): + ranker.run(query=query, documents=string_input) + + with pytest.raises( + TypeError, + match="FastembedRanker expects a list of Documents as input.", + ): + ranker.run(query=query, documents=list_integers_input) + + with pytest.raises( + ValueError, + match="No query provided", + ): + ranker.run(query="", documents=list_document) + + with pytest.raises( + ValueError, + match="top_k must be > 0, but got -3", + ): + ranker.run(query=query, documents=list_document, top_k=-3) + + def test_run_no_warmup(self): + """ + Test for checking error when calling without a warmup. + """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + + query = "query" + list_document = [Document("Document 1")] + + with pytest.raises( + RuntimeError, + ): + ranker.run(query=query, documents=list_document) + + def test_run_empty_document_list(self): + """ + Test for no error when sending no documents. + """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + ranker._model = "mock_model" + + query = "query" + list_document = [] + + result = ranker.run(query=query, documents=list_document) + assert len(result["documents"]) == 0 + + def test_embed_metadata(self): + """ + Tests the embedding of metadata fields in document content for ranking. + """ + ranker = FastembedRanker( + model_name="model_name", + meta_fields_to_embed=["meta_field"], + ) + ranker._model = MagicMock() + + documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] + query = "test" + ranker.run(query=query, documents=documents) + + ranker._model.rerank.assert_called_once_with( + query=query, + documents=[ + "meta_value 0\ndocument-number 0", + "meta_value 1\ndocument-number 1", + "meta_value 2\ndocument-number 2", + "meta_value 3\ndocument-number 3", + "meta_value 4\ndocument-number 4", + ], + batch_size=64, + parallel=None, + ) + + @pytest.mark.integration + def test_run(self): + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2", top_k=2) + ranker.warm_up() + + query = "Who is maintaining Qdrant?" + documents = [ + Document( + content="This is built to be faster and lighter than other embedding \ +libraries e.g. Transformers, Sentence-Transformers, etc." 
+ ), + Document(content="This is some random input"), + Document(content="fastembed is supported by and maintained by Qdrant."), + ] + + result = ranker.run(query=query, documents=documents) + + assert len(result["documents"]) == 2 + first_document = result["documents"][0] + second_document = result["documents"][1] + + assert isinstance(first_document, Document) + assert isinstance(second_document, Document) + assert first_document.content == "fastembed is supported by and maintained by Qdrant." + assert first_document.score > second_document.score From 025a05a85ceeef43fb678399b10cde4d0a450d77 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 13 Nov 2024 19:00:48 +0000 Subject: [PATCH 072/229] Update the changelog --- integrations/fastembed/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/fastembed/CHANGELOG.md b/integrations/fastembed/CHANGELOG.md index b5c194d8b..5dd62d130 100644 --- a/integrations/fastembed/CHANGELOG.md +++ b/integrations/fastembed/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/fastembed-v1.4.0] - 2024-11-13 + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/fastembed-v1.3.0] - 2024-10-07 ### 🚀 Features From ee4ca757d093c67b46215193faaa139def0e9c8c Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 14 Nov 2024 10:48:22 +0100 Subject: [PATCH 073/229] fix: deepeval - pin indirect dependencies based on python version (#1187) * try pinning pydantic * retry * again * more precise pin * fix * better --- integrations/deepeval/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/deepeval/pyproject.toml b/integrations/deepeval/pyproject.toml index 6ef64387b..78cc2542a 100644 --- a/integrations/deepeval/pyproject.toml +++ b/integrations/deepeval/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "deepeval==0.20.57"] +dependencies = ["haystack-ai", "deepeval==0.20.57", "langchain<0.3; python_version < '3.10'"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/deepeval" From a986ace82f0906b948d5e350a9f5a160f770059b Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 14 Nov 2024 09:50:16 +0000 Subject: [PATCH 074/229] Update the changelog --- integrations/deepeval/CHANGELOG.md | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 integrations/deepeval/CHANGELOG.md diff --git a/integrations/deepeval/CHANGELOG.md b/integrations/deepeval/CHANGELOG.md new file mode 100644 index 000000000..a296c7cfa --- /dev/null +++ b/integrations/deepeval/CHANGELOG.md @@ -0,0 +1,35 @@ +# Changelog + +## [integrations/deepeval-v0.1.2] - 2024-11-14 + +### 🚀 Features + +- Implement `DeepEvalEvaluator` (#346) + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Deepeval - pin indirect dependencies based on python version (#1187) + +### 📚 Documentation + +- Update paths and titles (#397) +- Update category slug (#442) +- Update `deepeval-haystack` docstrings (#527) +- Disable-class-def (#556) + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) + +### ⚙️ Miscellaneous Tasks + +- Exculde evaluator private classes in API docs (#392) +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and 
settings (#1105) +- Adopt uv as installer (#1142) + + From b5c9b2a8870fb079dada9afd93a07b9c043d8557 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 14 Nov 2024 12:23:23 +0100 Subject: [PATCH 075/229] fix: VertexAIGeminiGenerator - remove support for tools and change output type (#1180) --- .../generators/google_vertex/gemini.py | 63 ++------- .../google_vertex/tests/test_gemini.py | 125 ++---------------- 2 files changed, 21 insertions(+), 167 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 737f2e668..c9473b428 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -15,8 +15,6 @@ HarmBlockThreshold, HarmCategory, Part, - Tool, - ToolConfig, ) logger = logging.getLogger(__name__) @@ -50,6 +48,16 @@ class VertexAIGeminiGenerator: ``` """ + def __new__(cls, *_, **kwargs): + if "tools" in kwargs or "tool_config" in kwargs: + msg = ( + "VertexAIGeminiGenerator does not support `tools` and `tool_config` parameters. " + "Use VertexAIGeminiChatGenerator instead." + ) + raise TypeError(msg) + return super(VertexAIGeminiGenerator, cls).__new__(cls) # noqa: UP008 + # super(__class__, cls) is needed because of the component decorator + def __init__( self, *, @@ -58,8 +66,6 @@ def __init__( location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, - tools: Optional[List[Tool]] = None, - tool_config: Optional[ToolConfig] = None, system_instruction: Optional[Union[str, ByteStream, Part]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): @@ -86,10 +92,6 @@ def __init__( for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold) and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory) for more details. - :param tools: List of tools to use when generating content. See the documentation for - [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) - the list of supported arguments. - :param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig) :param system_instruction: Default system instruction to use for generating content. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. 
@@ -105,8 +107,6 @@ def __init__( # model parameters self._generation_config = generation_config self._safety_settings = safety_settings - self._tools = tools - self._tool_config = tool_config self._system_instruction = system_instruction self._streaming_callback = streaming_callback @@ -115,8 +115,6 @@ def __init__( self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, - tool_config=self._tool_config, system_instruction=self._system_instruction, ) @@ -132,18 +130,6 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A "stop_sequences": config._raw_generation_config.stop_sequences, } - def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]: - """Serializes the ToolConfig object into a dictionary.""" - - mode = tool_config._gapic_tool_config.function_calling_config.mode - allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names - config_dict = {"function_calling_config": {"mode": mode}} - - if allowed_function_names: - config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names - - return config_dict - def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -160,15 +146,10 @@ def to_dict(self) -> Dict[str, Any]: location=self._location, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, - tool_config=self._tool_config, system_instruction=self._system_instruction, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config) + if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -184,22 +165,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": Deserialized component. 
""" - def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: - """Deserializes the ToolConfig object from a dictionary.""" - function_calling_config = config_dict["function_calling_config"] - return ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=function_calling_config["mode"], - allowed_function_names=function_calling_config.get("allowed_function_names"), - ) - ) - - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config) if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -215,7 +182,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -257,12 +224,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[str]: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 277851224..ff692c6f4 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -1,38 +1,17 @@ from unittest.mock import MagicMock, Mock, patch +import pytest from haystack import Pipeline from haystack.components.builders import PromptBuilder from haystack.dataclasses import StreamingChunk from vertexai.generative_models import ( - FunctionDeclaration, GenerationConfig, HarmBlockThreshold, HarmCategory, - Tool, - ToolConfig, ) from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") @@ -48,32 +27,28 @@ def test_init(mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] - assert gemini._tool_config == tool_config assert gemini._system_instruction == "Please provide brief answers." +def test_init_fails_with_tools_or_tool_config(): + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tools=["tool1", "tool2"]) + + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tool_config={"custom": "config"}) + + @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): @@ -88,8 +63,6 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model): "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, - "tool_config": None, "system_instruction": None, }, } @@ -108,21 +81,11 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) assert gemini.to_dict() == { @@ -141,34 +104,6 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, "streaming_callback": None, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. 
San Francisco, CA", - }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - "property_ordering": ["location", "unit"], - }, - } - ] - } - ], - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -186,9 +121,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, - "tools": None, "streaming_callback": None, - "tool_config": None, "system_instruction": None, }, } @@ -198,8 +131,6 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id is None assert gemini._location is None assert gemini._safety_settings is None - assert gemini._tools is None - assert gemini._tool_config is None assert gemini._system_instruction is None assert gemini._generation_config is None @@ -223,40 +154,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "stop_sequences": ["stop"], }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - "description": "Get the current weather in a given location", - } - ] - } - ], "streaming_callback": None, - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -266,13 +164,8 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) - assert isinstance(gemini._tool_config, ToolConfig) assert gemini._system_instruction == "Please provide brief answers." 
- assert ( - gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY - ) @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") From c2d1b20b7e9e1be3a1e16acd13fe826060b97b67 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 14 Nov 2024 11:26:14 +0000 Subject: [PATCH 076/229] Update the changelog --- integrations/google_vertex/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index ed2cc3c3b..ea2a8fb18 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/google_vertex-v3.0.0] - 2024-11-14 + +### 🐛 Bug Fixes + +- VertexAIGeminiGenerator - remove support for tools and change output type (#1180) + +### ⚙️ Miscellaneous Tasks + +- Fix Vertex tests (#1163) + ## [integrations/google_vertex-v2.2.0] - 2024-10-23 ### 🐛 Bug Fixes From 3c04cfec2f71bca26007612803e174e0b2eb3cd9 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 14 Nov 2024 14:53:54 +0100 Subject: [PATCH 077/229] fix: Fixes to NvidiaRanker (#1191) * Fixes to NvidiaRanker * Add inits and headers * More headers * updates * Reactivate test * Fix tests * Reenable test and add test --- integrations/nvidia/pyproject.toml | 2 +- .../src/haystack_integrations/__init__.py | 3 ++ .../components/__init__.py | 3 ++ .../components/embedders/__init__.py | 3 ++ .../components/embedders/nvidia/__init__.py | 4 ++ .../embedders/nvidia/document_embedder.py | 11 +++-- .../embedders/nvidia/text_embedder.py | 7 ++- .../components/embedders/nvidia/truncate.py | 4 ++ .../components/generators/__init__.py | 3 ++ .../components/generators/nvidia/__init__.py | 1 + .../components/generators/nvidia/generator.py | 1 + .../components/rankers/__init__.py | 3 ++ .../components/rankers/nvidia/__init__.py | 4 ++ .../components/rankers/nvidia/ranker.py | 27 ++++++---- .../components/rankers/nvidia/truncate.py | 4 ++ .../haystack_integrations/utils/__init__.py | 3 ++ .../utils/nvidia/__init__.py | 4 ++ .../utils/nvidia/nim_backend.py | 4 ++ .../utils/nvidia/utils.py | 8 ++- integrations/nvidia/tests/__init__.py | 1 + integrations/nvidia/tests/conftest.py | 4 ++ integrations/nvidia/tests/test_base_url.py | 4 ++ .../nvidia/tests/test_document_embedder.py | 27 ++++++++-- .../tests/test_embedding_truncate_mode.py | 4 ++ integrations/nvidia/tests/test_generator.py | 1 + integrations/nvidia/tests/test_ranker.py | 49 +++++++++++++++++++ .../nvidia/tests/test_text_embedder.py | 21 +++++++- 27 files changed, 188 insertions(+), 22 deletions(-) create mode 100644 integrations/nvidia/src/haystack_integrations/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/generators/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py create mode 100644 integrations/nvidia/src/haystack_integrations/utils/__init__.py diff --git a/integrations/nvidia/pyproject.toml b/integrations/nvidia/pyproject.toml index 7f0048c1b..586b50848 100644 --- a/integrations/nvidia/pyproject.toml +++ b/integrations/nvidia/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming 
Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "requests"] +dependencies = ["haystack-ai", "requests", "tqdm"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme" diff --git a/integrations/nvidia/src/haystack_integrations/__init__.py b/integrations/nvidia/src/haystack_integrations/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/__init__.py b/integrations/nvidia/src/haystack_integrations/components/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py index bc2d9372c..827ad7dc6 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .document_embedder import NvidiaDocumentEmbedder from .text_embedder import NvidiaTextEmbedder from .truncate import EmbeddingTruncateMode diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index d746a75f4..606ec78fd 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional, Tuple, Union @@ -5,10 +9,9 @@ from haystack.utils import Secret, deserialize_secrets_inplace from tqdm import tqdm +from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from .truncate import EmbeddingTruncateMode - _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @@ -167,7 +170,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaDocumentEmbedder": :returns: The deserialized component. 
""" - deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 22bed8197..4b7072f33 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -1,13 +1,16 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from .truncate import EmbeddingTruncateMode - _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py index 3a8eb9d07..931c3cce3 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py index 18354ea17..b809d83b9 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from .generator import NvidiaGenerator __all__ = ["NvidiaGenerator"] diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py index 3eadcc5df..5bf71a9e1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional diff --git 
a/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py index 29cb2f7f5..05daa1c54 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .ranker import NvidiaRanker __all__ = ["NvidiaRanker"] diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index 1553d1ac3..9938b37d1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -1,12 +1,17 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional, Union -from haystack import Document, component, default_from_dict, default_to_dict +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode from haystack_integrations.utils.nvidia import NimBackend, url_validation -from .truncate import RankerTruncateMode +logger = logging.getLogger(__name__) _DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3" @@ -51,7 +56,7 @@ def __init__( model: Optional[str] = None, truncate: Optional[Union[RankerTruncateMode, str]] = None, api_url: Optional[str] = None, - api_key: Optional[Secret] = None, + api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), top_k: int = 5, ): """ @@ -100,6 +105,7 @@ def __init__( self._api_key = Secret.from_env_var("NVIDIA_API_KEY") self._top_k = top_k self._initialized = False + self._backend: Optional[Any] = None def to_dict(self) -> Dict[str, Any]: """ @@ -113,7 +119,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, truncate=self._truncate, api_url=self._api_url, - api_key=self._api_key, + api_key=self._api_key.to_dict() if self._api_key else None, ) @classmethod @@ -124,7 +130,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaRanker": :param data: A dictionary containing the ranker's attributes. :returns: The deserialized ranker. """ - deserialize_secrets_inplace(data, keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def warm_up(self): @@ -170,16 +178,16 @@ def run( msg = "The ranker has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) if not isinstance(query, str): - msg = "Ranker expects the `query` parameter to be a string." + msg = "NvidiaRanker expects the `query` parameter to be a string." 
raise TypeError(msg) if not isinstance(documents, list): - msg = "Ranker expects the `documents` parameter to be a list." + msg = "NvidiaRanker expects the `documents` parameter to be a list." raise TypeError(msg) if not all(isinstance(doc, Document) for doc in documents): - msg = "Ranker expects the `documents` parameter to be a list of Document objects." + msg = "NvidiaRanker expects the `documents` parameter to be a list of Document objects." raise TypeError(msg) if top_k is not None and not isinstance(top_k, int): - msg = "Ranker expects the `top_k` parameter to be an integer." + msg = "NvidiaRanker expects the `top_k` parameter to be an integer." raise TypeError(msg) if len(documents) == 0: @@ -187,6 +195,7 @@ def run( top_k = top_k if top_k is not None else self._top_k if top_k < 1: + logger.warning("top_k should be at least 1, returning nothing") warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2) return {"documents": []} diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py index 3b5d7f40a..649ceaf9d 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum diff --git a/integrations/nvidia/src/haystack_integrations/utils/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py index da301d29d..f08cda6cd 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .nim_backend import Model, NimBackend from .utils import is_hosted, url_validation diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index cbb6b7c3f..0279cf608 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py index 7d4dfc3b4..f07989405 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py @@ -1,9 +1,13 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings -from typing import List +from typing import List, Optional from urllib.parse import urlparse, 
urlunparse -def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str: +def url_validation(api_url: str, default_api_url: Optional[str], allowed_paths: List[str]) -> str: """ Validate and normalize an API URL. diff --git a/integrations/nvidia/tests/__init__.py b/integrations/nvidia/tests/__init__.py index 47611e0b9..38adc654d 100644 --- a/integrations/nvidia/tests/__init__.py +++ b/integrations/nvidia/tests/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from .conftest import MockBackend __all__ = ["MockBackend"] diff --git a/integrations/nvidia/tests/conftest.py b/integrations/nvidia/tests/conftest.py index a6c78ba4e..b6346c672 100644 --- a/integrations/nvidia/tests/conftest.py +++ b/integrations/nvidia/tests/conftest.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from typing import Any, Dict, List, Optional, Tuple import pytest diff --git a/integrations/nvidia/tests/test_base_url.py b/integrations/nvidia/tests/test_base_url.py index 426bacc25..506fbc385 100644 --- a/integrations/nvidia/tests/test_base_url.py +++ b/integrations/nvidia/tests/test_base_url.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index db69053e7..7e0e02f3d 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -104,7 +108,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } - def from_dict(self, monkeypatch): + def test_from_dict(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", @@ -122,15 +126,32 @@ def from_dict(self, monkeypatch): }, } component = NvidiaDocumentEmbedder.from_dict(data) - assert component.model == "nvolveqa_40k" + assert component.model == "playground_nvolveqa_40k" assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" + assert component.batch_size == 10 + assert component.progress_bar is False + assert component.meta_fields_to_embed == ["test_field"] + assert component.embedding_separator == " | " + assert component.truncate == EmbeddingTruncateMode.START + + def test_from_dict_defaults(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", + "init_parameters": {}, + } + component = NvidiaDocumentEmbedder.from_dict(data) + assert component.model == "nvidia/nv-embedqa-e5-v5" + assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" + assert component.prefix == "" + assert component.suffix == "" assert component.batch_size == 32 assert component.progress_bar assert component.meta_fields_to_embed == [] assert component.embedding_separator == "\n" - assert component.truncate == EmbeddingTruncateMode.START + assert component.truncate is None def 
test_prepare_texts_to_embed_w_metadata(self): documents = [ diff --git a/integrations/nvidia/tests/test_embedding_truncate_mode.py b/integrations/nvidia/tests/test_embedding_truncate_mode.py index e74d0308c..16f9112ea 100644 --- a/integrations/nvidia/tests/test_embedding_truncate_mode.py +++ b/integrations/nvidia/tests/test_embedding_truncate_mode.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 0bd8b1fc6..055830ae5 100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import os import pytest diff --git a/integrations/nvidia/tests/test_ranker.py b/integrations/nvidia/tests/test_ranker.py index 566fd18a8..d66bb0f65 100644 --- a/integrations/nvidia/tests/test_ranker.py +++ b/integrations/nvidia/tests/test_ranker.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import re from typing import Any, Optional, Union @@ -256,3 +260,48 @@ def test_warm_up_once(self, monkeypatch) -> None: backend = client._backend client.warm_up() assert backend == client._backend + + def test_to_dict(self) -> None: + client = NvidiaRanker() + assert client.to_dict() == { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": { + "model": "nvidia/nv-rerankqa-mistral-4b-v3", + "top_k": 5, + "truncate": None, + "api_url": None, + "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + }, + } + + def test_from_dict(self) -> None: + client = NvidiaRanker.from_dict( + { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": { + "model": "nvidia/nv-rerankqa-mistral-4b-v3", + "top_k": 5, + "truncate": None, + "api_url": None, + "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + }, + } + ) + assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client._top_k == 5 + assert client._truncate is None + assert client._api_url is None + assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") + + def test_from_dict_defaults(self) -> None: + client = NvidiaRanker.from_dict( + { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": {}, + } + ) + assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client._top_k == 5 + assert client._truncate is None + assert client._api_url is None + assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 8690de6b1..278fa5191 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -77,7 +81,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } - def from_dict(self, monkeypatch): + def test_from_dict(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") data = { "type": 
"haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", @@ -95,7 +99,20 @@ def from_dict(self, monkeypatch): assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" - assert component.truncate == "START" + assert component.truncate == EmbeddingTruncateMode.START + + def test_from_dict_defaults(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", + "init_parameters": {}, + } + component = NvidiaTextEmbedder.from_dict(data) + assert component.model == "nvidia/nv-embedqa-e5-v5" + assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" + assert component.prefix == "" + assert component.suffix == "" + assert component.truncate is None @pytest.mark.usefixtures("mock_local_models") def test_run_default_model(self): From 1ef03c07b5eddd3be819f11b4ea2b40251f9d079 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 14 Nov 2024 14:00:34 +0000 Subject: [PATCH 078/229] Update the changelog --- integrations/nvidia/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index 75b31d033..a536e431d 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/nvidia-v0.1.1] - 2024-11-14 + +### 🐛 Bug Fixes + +- Fixes to NvidiaRanker (#1191) + ## [integrations/nvidia-v0.1.0] - 2024-11-13 ### 🚀 Features From 3b3395845cd66259ba27f27e65be5a87efc9f98f Mon Sep 17 00:00:00 2001 From: rblst Date: Thu, 14 Nov 2024 19:09:24 +0100 Subject: [PATCH 079/229] feat: Add schema support to pgvector document store. (#1095) * Add schema support for the pgvector document store. Using the public schema of a PostgreSQL database is an anti-pattern. This change adds support for using a schema other than the public schema to create tables. * Fix long lines. * Fix long lines. Remove trailing spaces. * Fix trailing spaces. * Fix last trailing space. * Fix ruff issues. * Fix trailing space. 
* small fixes --------- Co-authored-by: Stefano Fiorucci --- .../pgvector/document_store.py | 70 +++++++++++++------ .../pgvector/tests/test_document_store.py | 3 + .../pgvector/tests/test_retrievers.py | 2 + 3 files changed, 54 insertions(+), 21 deletions(-) diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 1b1333f5c..8e9c0f2fc 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) CREATE_TABLE_STATEMENT = """ -CREATE TABLE IF NOT EXISTS {table_name} ( +CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ( id VARCHAR(128) PRIMARY KEY, embedding VECTOR({embedding_dimension}), content TEXT, @@ -36,7 +36,7 @@ """ INSERT_STATEMENT = """ -INSERT INTO {table_name} +INSERT INTO {schema_name}.{table_name} (id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta) VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s) """ @@ -54,7 +54,7 @@ KEYWORD_QUERY = """ SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score -FROM {table_name}, plainto_tsquery({language}, %s) query +FROM {schema_name}.{table_name}, plainto_tsquery({language}, %s) query WHERE to_tsvector({language}, content) @@ query """ @@ -78,6 +78,7 @@ def __init__( self, *, connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), + schema_name: str = "public", table_name: str = "haystack_documents", language: str = "english", embedding_dimension: int = 768, @@ -101,6 +102,7 @@ def __init__( e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"` See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) for more details. + :param schema_name: The name of the schema the table is created in. The schema must already exist. :param table_name: The name of the table to use to store Haystack documents. :param language: The language to be used to parse query and document content in keyword retrieval. 
To see the list of available languages, you can run the following SQL query in your PostgreSQL database: @@ -137,6 +139,7 @@ def __init__( self.connection_string = connection_string self.table_name = table_name + self.schema_name = schema_name self.embedding_dimension = embedding_dimension if vector_function not in VALID_VECTOR_FUNCTIONS: msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" @@ -207,6 +210,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, connection_string=self.connection_string.to_dict(), + schema_name=self.schema_name, table_name=self.table_name, embedding_dimension=self.embedding_dimension, vector_function=self.vector_function, @@ -266,7 +270,9 @@ def _create_table_if_not_exists(self): """ create_sql = SQL(CREATE_TABLE_STATEMENT).format( - table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension) + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + embedding_dimension=SQLLiteral(self.embedding_dimension), ) self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore") @@ -274,12 +280,18 @@ def _create_table_if_not_exists(self): def delete_table(self): """ Deletes the table used to store Haystack documents. - The name of the table (`table_name`) is defined when initializing the `PgvectorDocumentStore`. + The name of the schema (`schema_name`) and the name of the table (`table_name`) + are defined when initializing the `PgvectorDocumentStore`. """ + delete_sql = SQL("DROP TABLE IF EXISTS {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + ) - delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name)) - - self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") + self._execute_sql( + delete_sql, + error_msg=f"Could not delete table {self.schema_name}.{self.table_name} in PgvectorDocumentStore", + ) def _create_keyword_index_if_not_exists(self): """ @@ -287,15 +299,16 @@ def _create_keyword_index_if_not_exists(self): """ index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.keyword_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.keyword_index_name), "Could not check if keyword index exists", ).fetchone() ) sql_create_index = SQL( - "CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))" + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING GIN (to_tsvector({language}, content))" ).format( + schema_name=Identifier(self.schema_name), index_name=Identifier(self.keyword_index_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), @@ -318,8 +331,8 @@ def _handle_hnsw(self): index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.hnsw_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.hnsw_index_name), "Could not check if HNSW index exists", ).fetchone() ) @@ -349,8 +362,13 @@ def _create_hnsw_index(self): if key in HNSW_INDEX_CREATION_VALID_KWARGS } - sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding 
{ops}) ").format( - index_name=Identifier(self.hnsw_index_name), table_name=Identifier(self.table_name), ops=SQL(pg_ops) + sql_create_index = SQL( + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING hnsw (embedding {ops}) " + ).format( + schema_name=Identifier(self.schema_name), + index_name=Identifier(self.hnsw_index_name), + table_name=Identifier(self.table_name), + ops=SQL(pg_ops), ) if actual_hnsw_index_creation_kwargs: @@ -369,7 +387,9 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. """ - sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_count = SQL("SELECT COUNT(*) FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[ 0 @@ -395,7 +415,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." raise ValueError(msg) - sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_filter = SQL("SELECT * FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) params = () if filters: @@ -434,7 +456,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D db_documents = self._from_haystack_to_pg_documents(documents) - sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name)) + sql_insert = SQL(INSERT_STATEMENT).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) if policy == DuplicatePolicy.OVERWRITE: sql_insert += SQL(UPDATE_STATEMENT) @@ -543,8 +567,10 @@ def delete_documents(self, document_ids: List[str]) -> None: document_ids_str = ", ".join(f"'{document_id}'" for document_id in document_ids) - delete_sql = SQL("DELETE FROM {table_name} WHERE id IN ({document_ids_str})").format( - table_name=Identifier(self.table_name), document_ids_str=SQL(document_ids_str) + delete_sql = SQL("DELETE FROM {schema_name}.{table_name} WHERE id IN ({document_ids_str})").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + document_ids_str=SQL(document_ids_str), ) self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") @@ -570,6 +596,7 @@ def _keyword_retrieval( raise ValueError(msg) sql_select = SQL(KEYWORD_QUERY).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), query=SQLLiteral(query), @@ -643,7 +670,8 @@ def _embedding_retrieval( elif vector_function == "l2_distance": score_definition = f"embedding <-> {query_embedding_for_postgres} AS score" - sql_select = SQL("SELECT *, {score} FROM {table_name}").format( + sql_select = SQL("SELECT *, {score} FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), score=SQL(score_definition), ) diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 93514b71c..4af4fc8de 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -47,6 +47,7 @@ def test_init(monkeypatch): 
monkeypatch.setenv("PG_CONN_STR", "some_connection_string") document_store = PgvectorDocumentStore( + schema_name="my_schema", table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -59,6 +60,7 @@ def test_init(monkeypatch): keyword_index_name="my_keyword_index", ) + assert document_store.schema_name == "my_schema" assert document_store.table_name == "my_table" assert document_store.embedding_dimension == 512 assert document_store.vector_function == "l2_distance" @@ -93,6 +95,7 @@ def test_to_dict(monkeypatch): "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, "table_name": "my_table", + "schema_name": "public", "embedding_dimension": 512, "vector_function": "l2_distance", "recreate_table": True, diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py index 290891307..4125c3e3a 100644 --- a/integrations/pgvector/tests/test_retrievers.py +++ b/integrations/pgvector/tests/test_retrievers.py @@ -50,6 +50,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -175,6 +176,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", From e21ce0cd2315d95cae514c0c9a49d15c62499210 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 15 Nov 2024 14:03:04 +0100 Subject: [PATCH 080/229] Add AnthropicVertexChatGenerator component (#1192) * Created a model adapter * Create adapter class and add VertexAPI * Add chat generator for Anthropic Vertex * Add tests * Small fix * Improve doc_strings * Make project_id and region mandatory params * Small fix --- .../generators/anthropic/__init__.py | 3 +- .../anthropic/chat/vertex_chat_generator.py | 135 ++++++++++++ .../tests/test_vertex_chat_generator.py | 197 ++++++++++++++++++ 3 files changed, 334 insertions(+), 1 deletion(-) create mode 100644 integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py create mode 100644 integrations/anthropic/tests/test_vertex_chat_generator.py diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py index c2c1ee40d..0bd29898e 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AnthropicChatGenerator +from .chat.vertex_chat_generator import AnthropicVertexChatGenerator from .generator import AnthropicGenerator -__all__ = ["AnthropicGenerator", "AnthropicChatGenerator"] +__all__ = ["AnthropicGenerator", "AnthropicChatGenerator", "AnthropicVertexChatGenerator"] diff --git 
a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py new file mode 100644 index 000000000..4ece944cd --- /dev/null +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py @@ -0,0 +1,135 @@ +import os +from typing import Any, Callable, Dict, Optional + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import StreamingChunk +from haystack.utils import deserialize_callable, serialize_callable + +from anthropic import AnthropicVertex + +from .chat_generator import AnthropicChatGenerator + +logger = logging.getLogger(__name__) + + +@component +class AnthropicVertexChatGenerator(AnthropicChatGenerator): + """ + + Enables text generation using state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API. + It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`, + accessible through the Vertex AI API endpoint. + + To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled. + Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden. + Before making requests, you may need to authenticate with GCP using `gcloud auth login`. + For more details, refer to the [guide] (https://docs.anthropic.com/en/api/claude-on-vertex-ai). + + Any valid text generation parameters for the Anthropic messaging API can be passed to + the AnthropicVertex API. Users can provide these parameters directly to the component via + the `generation_kwargs` parameter in `__init__` or the `run` method. + + For more details on the parameters supported by the Anthropic API, refer to the + Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages). + + ```python + from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + from haystack.dataclasses import ChatMessage + + messages = [ChatMessage.from_user("What's Natural Language Processing?")] + client = AnthropicVertexChatGenerator( + model="claude-3-sonnet@20240229", + project_id="your-project-id", region="your-region" + ) + response = client.run(messages) + print(response) + + >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that + >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing + >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and + >> communicate in natural languages like English, Spanish, or Chinese.', role=, + >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn', + >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]} + ``` + + For more details on supported models and their capabilities, refer to the Anthropic + [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude). + + """ + + def __init__( + self, + region: str, + project_id: str, + model: str = "claude-3-5-sonnet@20240620", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ignore_tools_thinking_messages: bool = True, + ): + """ + Creates an instance of AnthropicVertexChatGenerator. + + :param region: The region where the Anthropic model is deployed. 
Defaults to "us-central1". + :param project_id: The GCP project ID where the Anthropic model is deployed. + :param model: The name of the model to use. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to + the AnthropicVertex endpoint. See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post) + for more details. + + Supported generation_kwargs parameters are: + - `system`: The system message to be passed to the model. + - `max_tokens`: The maximum number of tokens to generate. + - `metadata`: A dictionary of metadata to be passed to the model. + - `stop_sequences`: A list of strings that the model should stop generating at. + - `temperature`: The temperature to use for sampling. + - `top_p`: The top_p value to use for nucleus sampling. + - `top_k`: The top_k value to use for top-k sampling. + - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features). + :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a + "chain of thought" messages before returning the actual function names and parameters in a message. If + `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool + use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) + for more details. + """ + self.region = region or os.environ.get("REGION") + self.project_id = project_id or os.environ.get("PROJECT_ID") + self.model = model + self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback + self.client = AnthropicVertex(region=self.region, project_id=self.project_id) + self.ignore_tools_thinking_messages = ignore_tools_thinking_messages + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + region=self.region, + project_id=self.project_id, + model=self.model, + streaming_callback=callback_name, + generation_kwargs=self.generation_kwargs, + ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance. 
+ """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py new file mode 100644 index 000000000..a67e801ad --- /dev/null +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -0,0 +1,197 @@ +import os + +import anthropic +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + +class TestAnthropicVertexChatGenerator: + def test_init_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is None + assert not component.generation_kwargs + assert component.ignore_tools_thinking_messages + + def test_init_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ignore_tools_thinking_messages=False, + ) + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.ignore_tools_thinking_messages is False + + def test_to_dict_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": None, + "generation_kwargs": {}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." 
+ "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_lambda_streaming_callback(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=lambda x: x, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "tests.test_vertex_chat_generator.", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_from_dict(self): + data = { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + component = AnthropicVertexChatGenerator.from_dict(data) + assert component.model == "claude-3-5-sonnet@20240620" + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_run(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + response = component.run(chat_messages) + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_with_params(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator( + region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5} + ) + response = component.run(chat_messages) + + # check that the component calls the Anthropic API with the correct parameters + _, kwargs = mock_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run 
this test.", + ) + @pytest.mark.integration + def test_live_run_wrong_model(self, chat_messages): + component = AnthropicVertexChatGenerator( + model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID") + ) + with pytest.raises(anthropic.NotFoundError): + component.run(chat_messages) + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_default_inference_params(self, chat_messages): + client = AnthropicVertexChatGenerator( + region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229" + ) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" + + # Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint, + # remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator. From 67e08d0b7e5a7f51f52bb0d40fe40b0ff2caf43a Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 15 Nov 2024 15:23:30 +0100 Subject: [PATCH 081/229] Enable kwargs in SearchIndex and Embedding Retriever (#1185) * Enable kwargs for semantic ranking --- .../azure_ai_search/example/document_store.py | 3 +- .../example/embedding_retrieval.py | 5 +- .../azure_ai_search/embedding_retriever.py | 41 +++++++++------ .../azure_ai_search/__init__.py | 4 +- .../azure_ai_search/document_store.py | 50 +++++++++++-------- .../azure_ai_search/filters.py | 2 +- .../azure_ai_search/tests/conftest.py | 22 ++++++-- 7 files changed, 78 insertions(+), 49 deletions(-) diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py index 779f28935..92a641717 100644 --- a/integrations/azure_ai_search/example/document_store.py +++ b/integrations/azure_ai_search/example/document_store.py @@ -1,5 +1,4 @@ from haystack import Document -from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -30,7 +29,7 @@ meta={"version": 2.0, "label": "chapter_three"}, ), ] -document_store.write_documents(documents, policy=DuplicatePolicy.SKIP) +document_store.write_documents(documents) filters = { "operator": "AND", diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py index 088b08653..188f8525a 100644 --- a/integrations/azure_ai_search/example/embedding_retrieval.py +++ b/integrations/azure_ai_search/example/embedding_retrieval.py @@ -1,7 +1,6 @@ from haystack import Document, Pipeline from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.writers import DocumentWriter -from 
haystack.document_stores.types import DuplicatePolicy from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -38,9 +37,7 @@ # Indexing Pipeline indexing_pipeline = Pipeline() indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder") -indexing_pipeline.add_component( - instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer" -) +indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer") indexing_pipeline.connect("doc_embedder", "doc_writer") indexing_pipeline.run({"doc_embedder": {"documents": documents}}) diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py index ab649f874..af48b74fb 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -5,7 +5,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters logger = logging.getLogger(__name__) @@ -25,16 +25,23 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, ): """ Create the AzureAISearchEmbeddingRetriever component. :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. :param filters: Filters applied when fetching documents from the Document Store. - Filters are applied during the approximate kNN search to ensure the Retriever returns - `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. Possible options: + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). 
""" self._filters = filters or {} @@ -43,6 +50,7 @@ def __init__( self._filter_policy = ( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) + self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" @@ -61,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, document_store=self._document_store.to_dict(), filter_policy=self._filter_policy.value, + **self._kwargs, ) @classmethod @@ -88,29 +97,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever": def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): """Retrieve documents from the AzureAISearchDocumentStore. - :param query_embedding: floats representing the query embedding + :param query_embedding: A list of floats representing the query embedding. :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on - the `filter_policy` chosen at retriever initialization. See init method docstring for more - details. - :param top_k: the maximum number of documents to retrieve. - :returns: a dictionary with the following keys: - - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more + details. + :param top_k: The maximum number of documents to retrieve. + :returns: Dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. """ top_k = top_k or self._top_k if filters is not None: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" try: docs = self._document_store._embedding_retrieval( - query_embedding=query_embedding, - filters=normalized_filters, - top_k=top_k, + query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs ) except Exception as e: - raise e + msg = ( + "An error occurred during the embedding retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query embedding is valid and the document store is correctly configured." 
+ ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py index 635878a38..ca0ea7554 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore -from .filters import normalize_filters +from .filters import _normalize_filters -__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"] +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "_normalize_filters"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 0b59b6e37..74260b4fa 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -31,7 +31,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from .errors import AzureAISearchDocumentStoreConfigError -from .filters import normalize_filters +from .filters import _normalize_filters type_mapping = { str: "Edm.String", @@ -70,7 +70,7 @@ def __init__( embedding_dimension: int = 768, metadata_fields: Optional[Dict[str, type]] = None, vector_search_configuration: VectorSearch = None, - **kwargs, + **index_creation_kwargs, ): """ A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) @@ -87,19 +87,20 @@ def __init__( :param vector_search_configuration: Configuration option related to vector search. Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. - :param kwargs: Optional keyword parameters for Azure AI Search. - Some of the supported parameters: - - `api_version`: The Search API version to use for requests. - - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD). - The audience is not considered when using a shared key. If audience is not provided, - the public cloud audience will be assumed. + :param index_creation_kwargs: Optional keyword parameters to be passed to `SearchIndex` class + during index creation. Some of the supported parameters: + - `semantic_search`: Defines semantic configuration of the search index. This parameter is needed + to enable semantic search capabilities in index. + - `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents + matching a search query. The similarity algorithm can only be defined at index creation time and + cannot be modified on existing indexes. - For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/) + For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). 
""" azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None if not azure_endpoint: - msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." + msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT." raise ValueError(msg) api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None @@ -114,7 +115,7 @@ def __init__( self._dummy_vector = [-10.0] * self._embedding_dimension self._metadata_fields = metadata_fields self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH - self._kwargs = kwargs + self._index_creation_kwargs = index_creation_kwargs @property def client(self) -> SearchClient: @@ -128,7 +129,10 @@ def client(self) -> SearchClient: credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential() try: if not self._index_client: - self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs) + self._index_client = SearchIndexClient( + resolved_endpoint, + credential, + ) if not self._index_exists(self._index_name): # Create a new index if it does not exist logger.debug( @@ -151,7 +155,7 @@ def client(self) -> SearchClient: return self._client - def _create_index(self, index_name: str, **kwargs) -> None: + def _create_index(self, index_name: str) -> None: """ Creates a new search index. :param index_name: Name of the index to create. If None, the index name from the constructor is used. @@ -177,7 +181,10 @@ def _create_index(self, index_name: str, **kwargs) -> None: if self._metadata_fields: default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) index = SearchIndex( - name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs + name=index_name, + fields=default_fields, + vector_search=self._vector_search_configuration, + **self._index_creation_kwargs, ) if self._index_client: self._index_client.create_index(index) @@ -194,13 +201,13 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None, - api_key=self._api_key.to_dict() if self._api_key is not None else None, + azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None, + api_key=self._api_key.to_dict() if self._api_key else None, index_name=self._index_name, embedding_dimension=self._embedding_dimension, metadata_fields=self._metadata_fields, vector_search_configuration=self._vector_search_configuration.as_dict(), - **self._kwargs, + **self._index_creation_kwargs, ) @classmethod @@ -298,7 +305,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :returns: A list of Documents that match the given filters. """ if filters: - normalized_filters = normalize_filters(filters) + normalized_filters = _normalize_filters(filters) result = self.client.search(filter=normalized_filters) return self._convert_search_result_to_documents(result) else: @@ -409,8 +416,8 @@ def _embedding_retrieval( query_embedding: List[float], *, top_k: int = 10, - fields: Optional[List[str]] = None, filters: Optional[Dict[str, Any]] = None, + **kwargs, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. @@ -422,9 +429,10 @@ def _embedding_retrieval( `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. 
:param query_embedding: Embedding of the query. + :param top_k: Maximum number of Documents to return, defaults to 10. :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. - :param top_k: Maximum number of Documents to return, defaults to 10 + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. :raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` @@ -435,6 +443,6 @@ def _embedding_retrieval( raise ValueError(msg) vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") - result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) + result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py index 650e3f8be..0f105bc91 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py @@ -7,7 +7,7 @@ LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"} -def normalize_filters(filters: Dict[str, Any]) -> str: +def _normalize_filters(filters: Dict[str, Any]) -> str: """ Converts Haystack filters in Azure AI Search compatible filters. """ diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py index 3017c79c2..89369c87e 100644 --- a/integrations/azure_ai_search/tests/conftest.py +++ b/integrations/azure_ai_search/tests/conftest.py @@ -6,12 +6,14 @@ from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ResourceNotFoundError from azure.search.documents.indexes import SearchIndexClient +from haystack import logging from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore # This is the approximate time in seconds it takes for the documents to be available in Azure Search index -SLEEP_TIME_IN_SECONDS = 5 +SLEEP_TIME_IN_SECONDS = 10 +MAX_WAIT_TIME_FOR_INDEX_DELETION = 5 @pytest.fixture() @@ -46,23 +48,35 @@ def document_store(request): # Override some methods to wait for the documents to be available original_write_documents = store.write_documents + original_delete_documents = store.delete_documents def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): written_docs = original_write_documents(documents, policy) time.sleep(SLEEP_TIME_IN_SECONDS) return written_docs - original_delete_documents = store.delete_documents - def delete_documents_and_wait(filters): original_delete_documents(filters) time.sleep(SLEEP_TIME_IN_SECONDS) + # Helper function to wait for the index to be deleted, needed to cover latency + def wait_for_index_deletion(client, index_name): + start_time = time.time() + while time.time() - start_time < MAX_WAIT_TIME_FOR_INDEX_DELETION: + if index_name not in client.list_index_names(): + return True + time.sleep(1) + return False + store.write_documents = write_documents_and_wait 
store.delete_documents = delete_documents_and_wait yield store try: client.delete_index(index_name) + if not wait_for_index_deletion(client, index_name): + logging.error(f"Index {index_name} was not properly deleted.") except ResourceNotFoundError: - pass + logging.info(f"Index {index_name} was already deleted or not found.") + except Exception as e: + logging.error(f"Unexpected error when deleting index {index_name}: {e}") From 1e3ac825be60c39dcc1ea06cb11ff0eac386656c Mon Sep 17 00:00:00 2001 From: alex-stoica Date: Mon, 18 Nov 2024 17:46:28 +0200 Subject: [PATCH 082/229] Fixed TypeError in LangfuseTrace (#1184) * Added parent_span functionality in trace method * solved PR comments * Readded "end()" for solving Latency issues * chore: fix ruff linting * Handle multiple runs * Fix indentation and span closing * Fix tests --------- Co-authored-by: Vladimir Blagojevic Co-authored-by: Silvano Cerza --- .../tracing/langfuse/tracer.py | 102 ++++++++++-------- integrations/langfuse/tests/test_tracer.py | 2 +- integrations/langfuse/tests/test_tracing.py | 31 +++--- 3 files changed, 77 insertions(+), 58 deletions(-) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index c9c8a354e..c1f8d4d93 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -1,9 +1,10 @@ import contextlib -import logging import os +from contextvars import ContextVar from datetime import datetime -from typing import Any, Dict, Iterator, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union +from haystack import logging from haystack.components.generators.openai_utils import _convert_message_to_openai_format from haystack.dataclasses import ChatMessage from haystack.tracing import Span, Tracer, tracer @@ -32,6 +33,17 @@ ] _ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS +# These are the keys used by Haystack for traces and span. +# We keep them here to avoid making typos when using them. +_PIPELINE_RUN_KEY = "haystack.pipeline.run" +_COMPONENT_NAME_KEY = "haystack.component.name" +_COMPONENT_TYPE_KEY = "haystack.component.type" +_COMPONENT_OUTPUT_KEY = "haystack.component.output" + +# Context var used to keep track of tracing related info. +# This mainly useful for parents spans. +tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context", default={}) + class LangfuseSpan(Span): """ @@ -86,7 +98,7 @@ def set_content_tag(self, key: str, value: Any) -> None: self._data[key] = value - def raw_span(self) -> Any: + def raw_span(self) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]": """ Return the underlying span instance. @@ -115,41 +127,57 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: and only accessible to the Langfuse account owner. """ self._tracer = tracer - self._context: list[LangfuseSpan] = [] + self._context: List[LangfuseSpan] = [] self._name = name self._public = public self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true" @contextlib.contextmanager - def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> Iterator[Span]: - """ - Start and manage a new trace span. - :param operation_name: The name of the operation. - :param tags: A dictionary of tags to attach to the span. 
- :return: A context manager yielding the span. - """ + def trace( + self, operation_name: str, tags: Optional[Dict[str, Any]] = None, parent_span: Optional[Span] = None + ) -> Iterator[Span]: tags = tags or {} - span_name = tags.get("haystack.component.name", operation_name) - - if tags.get("haystack.component.type") in _ALL_SUPPORTED_GENERATORS: - span = LangfuseSpan(self.current_span().raw_span().generation(name=span_name)) + span_name = tags.get(_COMPONENT_NAME_KEY, operation_name) + + # Create new span depending whether there's a parent span or not + if not parent_span: + if operation_name != _PIPELINE_RUN_KEY: + logger.warning( + "Creating a new trace without a parent span is not recommended for operation '{operation_name}'.", + operation_name=operation_name, + ) + # Create a new trace if no parent span is provided + span = LangfuseSpan( + self._tracer.trace( + name=self._name, + public=self._public, + id=tracing_context_var.get().get("trace_id"), + user_id=tracing_context_var.get().get("user_id"), + session_id=tracing_context_var.get().get("session_id"), + tags=tracing_context_var.get().get("tags"), + version=tracing_context_var.get().get("version"), + ) + ) + elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS: + span = LangfuseSpan(parent_span.raw_span().generation(name=span_name)) else: - span = LangfuseSpan(self.current_span().raw_span().span(name=span_name)) + span = LangfuseSpan(parent_span.raw_span().span(name=span_name)) self._context.append(span) span.set_tags(tags) yield span - if tags.get("haystack.component.type") in _SUPPORTED_GENERATORS: - meta = span._data.get("haystack.component.output", {}).get("meta") + # Update span metadata based on component type + if tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_GENERATORS: + # Haystack returns one meta dict for each message, but the 'usage' value + # is always the same, let's just pick the first item + meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta") if meta: - # Haystack returns one meta dict for each message, but the 'usage' value - # is always the same, let's just pick the first item m = meta[0] span._span.update(usage=m.get("usage") or None, model=m.get("model")) - elif tags.get("haystack.component.type") in _SUPPORTED_CHAT_GENERATORS: - replies = span._data.get("haystack.component.output", {}).get("replies") + elif tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_CHAT_GENERATORS: + replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies") if replies: meta = replies[0].meta completion_start_time = meta.get("completion_start_time") @@ -165,36 +193,24 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I completion_start_time=completion_start_time, ) - pipeline_input = tags.get("haystack.pipeline.input_data", None) - if pipeline_input: - span._span.update(input=tags["haystack.pipeline.input_data"]) - pipeline_output = tags.get("haystack.pipeline.output_data", None) - if pipeline_output: - span._span.update(output=tags["haystack.pipeline.output_data"]) - - span.raw_span().end() + raw_span = span.raw_span() + if isinstance(raw_span, langfuse.client.StatefulSpanClient): + raw_span.end() self._context.pop() - if len(self._context) == 1: - # The root span has to be a trace, which need to be removed from the context after the pipeline run - self._context.pop() - - if self.enforce_flush: - self.flush() + if self.enforce_flush: + self.flush() def flush(self): self._tracer.flush() - def current_span(self) -> Span: + def current_span(self) -> Optional[Span]: """ - Return the 
currently active span. + Return the current active span. - :return: The currently active span. + :return: The current span if available, else None. """ - if not self._context: - # The root span has to be a trace - self._context.append(LangfuseSpan(self._tracer.trace(name=self._name, public=self._public))) - return self._context[-1] + return self._context[-1] if self._context else None def get_trace_url(self) -> str: """ diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index 9ee8e5dc4..42ae1d07d 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -69,7 +69,7 @@ def test_create_new_span(self): tracer = LangfuseTracer(tracer=mock_tracer, name="Haystack", public=False) with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span: - assert len(tracer._context) == 2, "The trace span should have been added to the the root context span" + assert len(tracer._context) == 1, "The trace span should have been added to the the root context span" assert span.raw_span().operation_name == "operation_name" assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"} diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 657b6eae1..e5737b861 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -52,25 +52,28 @@ def test_tracing_integration(llm_class, env_var, expected_trace): assert "Berlin" in response["llm"]["replies"][0].content assert response["tracer"]["trace_url"] - # add a random delay between 1 and 3 seconds to make sure the trace is flushed - # and that the trace is available in Langfuse when we fetch it below - time.sleep(random.uniform(1, 3)) - - url = "https://cloud.langfuse.com/api/public/traces/" trace_url = response["tracer"]["trace_url"] uuid = os.path.basename(urlparse(trace_url).path) + url = f"https://cloud.langfuse.com/api/public/traces/{uuid}" - try: - response = requests.get( - url + uuid, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) + # Poll the Langfuse API a bit as the trace might not be ready right away + attempts = 5 + delay = 1 + while attempts >= 0: + res = requests.get( + url, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) ) - assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}" + if attempts > 0 and res.status_code != 200: + attempts -= 1 + time.sleep(delay) + delay *= 2 + continue + assert res.status_code == 200, f"Failed to retrieve data from Langfuse API: {res.status_code}" # check if the trace contains the expected LLM name - assert expected_trace in str(response.content) + assert expected_trace in str(res.content) # check if the trace contains the expected generation span - assert "GENERATION" in str(response.content) + assert "GENERATION" in str(res.content) # check if the trace contains the expected user_id - assert "user_42" in str(response.content) - except requests.exceptions.RequestException as e: - pytest.fail(f"Failed to retrieve data from Langfuse API: {e}") + assert "user_42" in str(res.content) + break From 411441151e6f46cef0120324a6669c050639bac8 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 18 Nov 2024 15:51:37 +0000 Subject: [PATCH 083/229] Update the changelog --- integrations/langfuse/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git 
a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 29be7f838..7cf1cc0c4 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/langfuse-v0.6.0] - 2024-11-18 + +### 🚀 Features + +- Add support for ttft (#1161) + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/langfuse-v0.5.0] - 2024-10-01 ### ⚙️ Miscellaneous Tasks From 180dd3b35d287ab04f1f380d75f5254bbe4bbfbb Mon Sep 17 00:00:00 2001 From: paulmartrencharpro <148542350+paulmartrencharpro@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:09:52 +0100 Subject: [PATCH 084/229] README.md: add new example with the ranker and fix the old ones (#1198) ranker_example.py: fix the example --- integrations/fastembed/README.md | 37 ++++++++++++++++--- .../fastembed/examples/ranker_example.py | 2 +- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/integrations/fastembed/README.md b/integrations/fastembed/README.md index c021dec3b..f3c2bb135 100644 --- a/integrations/fastembed/README.md +++ b/integrations/fastembed/README.md @@ -8,6 +8,7 @@ **Table of Contents** - [Installation](#installation) +- [Usage](#Usage) - [License](#license) ## Installation @@ -33,7 +34,7 @@ embedding = text_embedder.run(text)["embedding"] ```python from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder -from haystack.dataclasses import Document +from haystack import Document embedder = FastembedDocumentEmbedder( model="BAAI/bge-small-en-v1.5", @@ -50,24 +51,50 @@ from haystack_integrations.components.embedders.fastembed import FastembedSparse text = "fastembed is supported by and maintained by Qdrant." text_embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1" + model="prithivida/Splade_PP_en_v1" ) text_embedder.warm_up() -embedding = text_embedder.run(text)["embedding"] +embedding = text_embedder.run(text)["sparse_embedding"] ``` ```python from haystack_integrations.components.embedders.fastembed import FastembedSparseDocumentEmbedder -from haystack.dataclasses import Document +from haystack import Document embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", ) embedder.warm_up() doc = Document(content="fastembed is supported by and maintained by Qdrant.", meta={"long_answer": "no",}) result = embedder.run(documents=[doc]) ``` +You can use `FastembedRanker` by importing as: + +```python +from haystack import Document + +from haystack_integrations.components.rankers.fastembed import FastembedRanker + +query = "Who is maintaining Qdrant?" +documents = [ + Document( + content="This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc." + ), + Document(content="fastembed is supported by and maintained by Qdrant."), +] + +ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2") +ranker.warm_up() +reranked_documents = ranker.run(query=query, documents=documents)["documents"] + +print(reranked_documents[0]) + +# Document(id=..., +# content: 'fastembed is supported by and maintained by Qdrant.', +# score: 5.472434997558594..) +``` + ## License `fastembed-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
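The fastembed snippets above return their results under different keys: the dense embedders expose a plain list of floats under `embedding`, while the sparse embedders expose theirs under `sparse_embedding`. A minimal sketch, assuming the default models already shown in this README, that runs both text embedders on the same query:

```python
from haystack_integrations.components.embedders.fastembed import (
    FastembedSparseTextEmbedder,
    FastembedTextEmbedder,
)

query = "Who is maintaining Qdrant?"

# Dense embedder: the vector is returned under the "embedding" key
dense_embedder = FastembedTextEmbedder(model="BAAI/bge-small-en-v1.5")
dense_embedder.warm_up()
dense_vector = dense_embedder.run(query)["embedding"]

# Sparse embedder: the result is returned under the "sparse_embedding" key
sparse_embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1")
sparse_embedder.warm_up()
sparse_embedding = sparse_embedder.run(query)["sparse_embedding"]

print(len(dense_vector))  # dimensionality of the dense vector
print(sparse_embedding)   # sparse representation of the same query
```

The document embedders follow the same convention, writing to the documents' `embedding` and `sparse_embedding` fields respectively.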
diff --git a/integrations/fastembed/examples/ranker_example.py b/integrations/fastembed/examples/ranker_example.py index 7a31e4646..593334e90 100644 --- a/integrations/fastembed/examples/ranker_example.py +++ b/integrations/fastembed/examples/ranker_example.py @@ -15,7 +15,7 @@ reranked_documents = ranker.run(query=query, documents=documents)["documents"] -print(reranked_documents["documents"][0]) +print(reranked_documents[0]) # Document(id=..., # content: 'fastembed is supported by and maintained by Qdrant.', From 3ed8dfb742f65904981493f8edfe642a99e63f84 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Tue, 19 Nov 2024 10:23:08 +0100 Subject: [PATCH 085/229] ci: Fix changelog generation (#1199) * Fix changelog creation to also show PRs that don't have conventional commits * Fix file glob for changelog creation * Handle commits with multiple lines * Add PR number * Hide commits that just update the changelog * Fix sorting --- .github/workflows/CI_pypi_release.yml | 2 +- cliff.toml | 74 ++++++++++++++++++++------- 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/.github/workflows/CI_pypi_release.yml b/.github/workflows/CI_pypi_release.yml index e29d3ea36..1162ca3b3 100644 --- a/.github/workflows/CI_pypi_release.yml +++ b/.github/workflows/CI_pypi_release.yml @@ -54,7 +54,7 @@ jobs: with: config: cliff.toml args: > - --include-path "${{ steps.pathfinder.outputs.project_path }}/*" + --include-path "${{ steps.pathfinder.outputs.project_path }}/**/*" --tag-pattern "${{ steps.pathfinder.outputs.project_path }}-v*" - name: Commit changelog diff --git a/cliff.toml b/cliff.toml index 29228543e..45e4c647b 100644 --- a/cliff.toml +++ b/cliff.toml @@ -19,11 +19,43 @@ body = """ ## [unreleased] {% endif %}\ {% for group, commits in commits | group_by(attribute="group") %} + {# + Skip the whole section if it contains only a single commit + and it's the commit that updated the changelog. + If we don't do this we get an empty section since we don't show + commits that update the changelog + #}\ + {% if commits | length == 1 and commits[0].message == 'Update the changelog' %}\ + {% continue %}\ + {% endif %}\ ### {{ group | striptags | trim | upper_first }} - {% for commit in commits %} + {% for commit in commits %}\ + {# + Skip commits that update the changelog, they're not useful to the user + #}\ + {% if commit.message == 'Update the changelog' %}\ + {% continue %}\ + {% endif %} - {% if commit.scope %}*({{ commit.scope }})* {% endif %}\ {% if commit.breaking %}[**breaking**] {% endif %}\ - {{ commit.message | upper_first }}\ + {# + We first try to render the conventional commit message if present. + If it's not a conventional commit we get the PR title if present. + If the commit is neither conventional, nor has a PR title set + we fallback to whatever the commit message is. + + We do this cause when merging PRs with multiple commits that don't + have a title following conventional commit guidelines we might get + a commit message that is multiple lines. That makes the changelog + look a bit funky so we handle it like so. 
+ #}\ + {% if commit.conventional %}\ + {{ commit.message | upper_first }}\ + {% elif commit.remote.pr_title %}\ + {{ commit.remote.pr_title | upper_first }} (#{{ commit.remote.pr_number }})\ + {% else %}\ + {{ commit.message | upper_first }}\ + {% endif %}\ {% endfor %} {% endfor %}\n """ @@ -35,7 +67,7 @@ footer = """ trim = true # postprocessors postprocessors = [ - # { pattern = '', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL + # { pattern = '', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL ] [git] @@ -47,24 +79,26 @@ filter_unconventional = false split_commits = false # regex for preprocessing the commit messages commit_preprocessors = [ - # Replace issue numbers - #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](/issues/${2}))"}, - # Check spelling of the commit with https://github.com/crate-ci/typos - # If the spelling is incorrect, it will be automatically fixed. - #{ pattern = '.*', replace_command = 'typos --write-changes -' }, + # Replace issue numbers + #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](/issues/${2}))"}, + # Check spelling of the commit with https://github.com/crate-ci/typos + # If the spelling is incorrect, it will be automatically fixed. + #{ pattern = '.*', replace_command = 'typos --write-changes -' }, ] # regex for parsing and grouping commits commit_parsers = [ - { message = "^feat", group = "🚀 Features" }, - { message = "^fix", group = "🐛 Bug Fixes" }, - { message = "^doc", group = "📚 Documentation" }, - { message = "^perf", group = "⚡ Performance" }, - { message = "^refactor", group = "🚜 Refactor" }, - { message = "^style", group = "🎨 Styling" }, - { message = "^test", group = "🧪 Testing" }, - { message = "^chore|^ci", group = "⚙️ Miscellaneous Tasks" }, - { body = ".*security", group = "🛡️ Security" }, - { message = "^revert", group = "◀️ Revert" }, + { message = "^feat", group = "🚀 Features" }, + { message = "^fix", group = "🐛 Bug Fixes" }, + { message = "^refactor", group = "🚜 Refactor" }, + { message = "^doc", group = "📚 Documentation" }, + { message = "^perf", group = "⚡ Performance" }, + { message = "^style", group = "🎨 Styling" }, + { message = "^test", group = "🧪 Testing" }, + { body = ".*security", group = "🛡️ Security" }, + { message = "^revert", group = "◀️ Revert" }, + { message = "^ci", group = "⚙️ CI" }, + { message = "^chore", group = "🧹 Chores" }, + { message = ".*", group = "🌀 Miscellaneous" }, ] # protect breaking changes from being skipped due to matching a skipping commit_parser protect_breaking_commits = false @@ -82,3 +116,7 @@ topo_order = false sort_commits = "oldest" # limit the number of commits included in the changelog. 
# limit_commits = 42 + +[remote.github] +owner = "deepset-ai" +repo = "haystack-core-integrations" From d00e7d91fe5fd0e08b73cb0bd6a5a51ca0a60c34 Mon Sep 17 00:00:00 2001 From: theoohoho <31537466+theoohoho@users.noreply.github.com> Date: Tue, 19 Nov 2024 23:38:11 +0800 Subject: [PATCH 086/229] fix: Fix missing usage metadata in GoogleAIGeminiChatGenerator (#1195) * fix: Fix missing usage metadata in GoogleAIGeminiChatGenerator * small fixes + test --------- Co-authored-by: anakin87 --- .../generators/google_ai/chat/gemini.py | 15 +++++++++++++++ .../tests/generators/chat/test_chat_gemini.py | 10 ++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 8efa8cda7..dbcab619d 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -313,9 +313,24 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess """ replies: List[ChatMessage] = [] metadata = response_body.to_dict() + + # currently Google only supports one candidate and usage metadata reflects this + # this should be refactored when multiple candidates are supported + usage_metadata_openai_format = {} + + usage_metadata = metadata.get("usage_metadata") + if usage_metadata: + usage_metadata_openai_format = { + "prompt_tokens": usage_metadata["prompt_token_count"], + "completion_tokens": usage_metadata["candidates_token_count"], + "total_tokens": usage_metadata["total_token_count"], + } + for idx, candidate in enumerate(response_body.candidates): candidate_metadata = metadata["candidates"][idx] candidate_metadata.pop("content", None) # we remove content from the metadata + if usage_metadata_openai_format: + candidate_metadata["usage"] = usage_metadata_openai_format for part in candidate.content.parts: if part.text != "": diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index c4372db0d..cb42f0ff8 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -295,5 +295,11 @@ def test_past_conversation(): ] response = gemini_chat.run(messages=messages) assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + replies = response["replies"] + assert len(replies) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in replies) + + assert all("usage" in reply.meta for reply in replies) + assert all("prompt_tokens" in reply.meta["usage"] for reply in replies) + assert all("completion_tokens" in reply.meta["usage"] for reply in replies) + assert all("total_tokens" in reply.meta["usage"] for reply in replies) From 8b7918b488ff28710793e9a7a5c6df8c672f5a34 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 19 Nov 2024 16:08:08 +0000 Subject: [PATCH 087/229] Update the changelog --- integrations/google_ai/CHANGELOG.md | 47 ++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 8f09db79a..7171b0069 100644 --- a/integrations/google_ai/CHANGELOG.md +++ 
b/integrations/google_ai/CHANGELOG.md @@ -1,15 +1,23 @@ # Changelog +## [integrations/google_ai-v3.0.2] - 2024-11-19 + +### 🐛 Bug Fixes + +- Fix missing usage metadata in GoogleAIGeminiChatGenerator (#1195) + + ## [integrations/google_ai-v3.0.0] - 2024-11-12 ### 🐛 Bug Fixes - `GoogleAIGeminiGenerator` - remove support for tools and change output type (#1177) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Adopt uv as installer (#1142) + ## [integrations/google_ai-v2.0.1] - 2024-10-15 ### 🚀 Features @@ -26,16 +34,22 @@ - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) - Update ruff linting scripts and settings (#1105) -### Docs +### 🌀 Miscellaneous +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Fix Google AI tests failing (#885) - Update GeminiGenerator docstrings (#964) - Update GoogleChatGenerator docstrings (#962) +- Feat: enable streaming in GoogleAIGemini (#1016) ## [integrations/google_ai-v1.1.0] - 2024-06-05 @@ -43,31 +57,50 @@ - Handle `TypeError: Could not create Blob` in `GoogleAIGeminiChatGenerator` (#772) +### 🌀 Miscellaneous + +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Fix Google AI integration tests (#786) +- Test: Fix tests skipping for Google AI integration (#788) + ## [integrations/google_ai-v1.0.0] - 2024-03-27 ### 🐛 Bug Fixes - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Google AI - review docstrings (#533) +- Make tests show coverage (#566) +- Remove references to Python 3.7 (#601) +- Google Generators: change `answers` to `replies` (#626) + ## [integrations/google_ai-v0.2.0] - 2024-02-15 -### Google_ai +### 🌀 Miscellaneous - Create api docs (#354) +- Google AI - new secrets management (#424) ## [integrations/google_ai-v0.1.0] - 2024-01-25 -### Refact +### 🌀 Miscellaneous +- Add docstrings for `GoogleAIGeminiGenerator` and `GoogleAIGeminiChatGenerator` (#175) - [**breaking**] Adjust import paths (#268) ## [integrations/google_ai-v0.0.1] - 2024-01-03 +### 🌀 Miscellaneous + +- Gemini with Makersuite (#156) +- Fix google_ai integration versioning + From 8755e4a52c4f6b506a5333968b1189b265d7cfc0 Mon Sep 17 00:00:00 2001 From: paulmartrencharpro <148542350+paulmartrencharpro@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:11:25 +0100 Subject: [PATCH 088/229] The name of the model prithvida/Splade_PP_en_v1 has a typo and is being replaced by prithivida/Splade_PP_en_v1? There's a Deprecation warning about it. 
(#1201) I changed it everywhere in the class & the tests --- .../fastembed_sparse_document_embedder.py | 6 ++-- .../fastembed_sparse_text_embedder.py | 6 ++-- ...test_fastembed_sparse_document_embedder.py | 36 +++++++++---------- .../test_fastembed_sparse_text_embedder.py | 32 ++++++++--------- 4 files changed, 40 insertions(+), 40 deletions(-) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index f79f08c90..a30d43cf4 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -16,7 +16,7 @@ class FastembedSparseDocumentEmbedder: from haystack.dataclasses import Document sparse_doc_embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", batch_size=32, ) @@ -53,7 +53,7 @@ class FastembedSparseDocumentEmbedder: def __init__( self, - model: str = "prithvida/Splade_PP_en_v1", + model: str = "prithivida/Splade_PP_en_v1", cache_dir: Optional[str] = None, threads: Optional[int] = None, batch_size: int = 32, @@ -68,7 +68,7 @@ def __init__( Create an FastembedDocumentEmbedder component. :param model: Local path or name of the model in Hugging Face's model hub, - such as `prithvida/Splade_PP_en_v1`. + such as `prithivida/Splade_PP_en_v1`. :param cache_dir: The path to the cache directory. Can be set using the `FASTEMBED_CACHE_PATH` env variable. Defaults to `fastembed_cache` in the system's temp directory. diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py index 2ebab35b4..c7296525f 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py @@ -19,7 +19,7 @@ class FastembedSparseTextEmbedder: "The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!") sparse_text_embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1" + model="prithivida/Splade_PP_en_v1" ) sparse_text_embedder.warm_up() @@ -29,7 +29,7 @@ class FastembedSparseTextEmbedder: def __init__( self, - model: str = "prithvida/Splade_PP_en_v1", + model: str = "prithivida/Splade_PP_en_v1", cache_dir: Optional[str] = None, threads: Optional[int] = None, progress_bar: bool = True, @@ -40,7 +40,7 @@ def __init__( """ Create a FastembedSparseTextEmbedder component. - :param model: Local path or name of the model in Fastembed's model hub, such as `prithvida/Splade_PP_en_v1` + :param model: Local path or name of the model in Fastembed's model hub, such as `prithivida/Splade_PP_en_v1` :param cache_dir: The path to the cache directory. Can be set using the `FASTEMBED_CACHE_PATH` env variable. Defaults to `fastembed_cache` in the system's temp directory. 
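With this patch the sparse embedders default to the corrected `prithivida/Splade_PP_en_v1` id; per the commit message, the old `prithvida/...` spelling triggers a deprecation warning in fastembed. A minimal sketch of the corrected usage, reusing the sparse text embedder API shown earlier in this series:

```python
from haystack_integrations.components.embedders.fastembed import FastembedSparseTextEmbedder

# The corrected model id below is now also the component's default;
# the old "prithvida/Splade_PP_en_v1" spelling reportedly triggers a deprecation warning.
sparse_text_embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1")
sparse_text_embedder.warm_up()

result = sparse_text_embedder.run("fastembed is supported by and maintained by Qdrant.")
print(result["sparse_embedding"])
```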
diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py index 90e94908d..7c0de196a 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py @@ -15,8 +15,8 @@ def test_init_default(self): """ Test default initialization parameters for FastembedSparseDocumentEmbedder. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None assert embedder.batch_size == 32 @@ -31,7 +31,7 @@ def test_init_with_parameters(self): Test custom initialization parameters for FastembedSparseDocumentEmbedder. """ embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, batch_size=64, @@ -41,7 +41,7 @@ def test_init_with_parameters(self): meta_fields_to_embed=["test_field"], embedding_separator=" | ", ) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 assert embedder.batch_size == 64 @@ -55,12 +55,12 @@ def test_to_dict(self): """ Test serialization of FastembedSparseDocumentEmbedder to a dictionary, using default initialization parameters. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") embedder_dict = embedder.to_dict() assert embedder_dict == { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": None, "threads": None, "batch_size": 32, @@ -78,7 +78,7 @@ def test_to_dict_with_custom_init_parameters(self): Test serialization of FastembedSparseDocumentEmbedder to a dictionary, using custom initialization parameters. 
""" embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, batch_size=64, @@ -92,7 +92,7 @@ def test_to_dict_with_custom_init_parameters(self): assert embedder_dict == { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, "batch_size": 64, @@ -113,7 +113,7 @@ def test_from_dict(self): embedder_dict = { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": None, "threads": None, "batch_size": 32, @@ -125,7 +125,7 @@ def test_from_dict(self): }, } embedder = default_from_dict(FastembedSparseDocumentEmbedder, embedder_dict) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None assert embedder.batch_size == 32 @@ -143,7 +143,7 @@ def test_from_dict_with_custom_init_parameters(self): embedder_dict = { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, "batch_size": 64, @@ -155,7 +155,7 @@ def test_from_dict_with_custom_init_parameters(self): }, } embedder = default_from_dict(FastembedSparseDocumentEmbedder, embedder_dict) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 assert embedder.batch_size == 64 @@ -172,11 +172,11 @@ def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", + model_name="prithivida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False, @@ -190,7 +190,7 @@ def test_warmup_does_not_reload(self, mocked_factory): """ Test for checking backend instances after multiple warm-ups. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() embedder.warm_up() @@ -211,7 +211,7 @@ def test_embed(self): """ Test for checking output dimensions and embedding dimensions. 
""" - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") embedder.embedding_backend = MagicMock() embedder.embedding_backend.embed = lambda x, **kwargs: self._generate_mocked_sparse_embedding( # noqa: ARG005 len(x) @@ -235,7 +235,7 @@ def test_embed_incorrect_input_format(self): """ Test for checking incorrect input format when creating embedding. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") string_input = "text" list_integers_input = [1, 2, 3] @@ -330,7 +330,7 @@ def test_run_with_model_kwargs(self): @pytest.mark.integration def test_run(self): embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", ) embedder.warm_up() diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py index 4f438fd15..9b73f5f3a 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py @@ -15,8 +15,8 @@ def test_init_default(self): """ Test default initialization parameters for FastembedSparseTextEmbedder. """ - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None assert embedder.progress_bar is True @@ -27,13 +27,13 @@ def test_init_with_parameters(self): Test custom initialization parameters for FastembedSparseTextEmbedder. """ embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, progress_bar=False, parallel=1, ) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 assert embedder.progress_bar is False @@ -43,12 +43,12 @@ def test_to_dict(self): """ Test serialization of FastembedSparseTextEmbedder to a dictionary, using default initialization parameters. """ - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") embedder_dict = embedder.to_dict() assert embedder_dict == { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": None, "threads": None, "progress_bar": True, @@ -63,7 +63,7 @@ def test_to_dict_with_custom_init_parameters(self): Test serialization of FastembedSparseTextEmbedder to a dictionary, using custom initialization parameters. 
""" embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, progress_bar=False, @@ -74,7 +74,7 @@ def test_to_dict_with_custom_init_parameters(self): assert embedder_dict == { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, "progress_bar": False, @@ -91,7 +91,7 @@ def test_from_dict(self): embedder_dict = { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": None, "threads": None, "progress_bar": True, @@ -99,7 +99,7 @@ def test_from_dict(self): }, } embedder = default_from_dict(FastembedSparseTextEmbedder, embedder_dict) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None assert embedder.progress_bar is True @@ -112,7 +112,7 @@ def test_from_dict_with_custom_init_parameters(self): embedder_dict = { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, "progress_bar": False, @@ -120,7 +120,7 @@ def test_from_dict_with_custom_init_parameters(self): }, } embedder = default_from_dict(FastembedSparseTextEmbedder, embedder_dict) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 assert embedder.progress_bar is False @@ -133,11 +133,11 @@ def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. """ - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", + model_name="prithivida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False, @@ -151,7 +151,7 @@ def test_warmup_does_not_reload(self, mocked_factory): """ Test for checking backend instances after multiple warm-ups. 
""" - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() embedder.warm_up() @@ -252,7 +252,7 @@ def test_run_with_model_kwargs(self): @pytest.mark.integration def test_run(self): embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", ) embedder.warm_up() From 51abafea60804b74e480b46090a9955fbe0839e3 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 19 Nov 2024 16:12:29 +0000 Subject: [PATCH 089/229] Update the changelog --- integrations/fastembed/CHANGELOG.md | 62 ++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 5 deletions(-) diff --git a/integrations/fastembed/CHANGELOG.md b/integrations/fastembed/CHANGELOG.md index 5dd62d130..841781660 100644 --- a/integrations/fastembed/CHANGELOG.md +++ b/integrations/fastembed/CHANGELOG.md @@ -1,11 +1,23 @@ # Changelog +## [integrations/fastembed-v1.4.1] - 2024-11-19 + +### 🌀 Miscellaneous + +- Add new example with the ranker and fix the old ones (#1198) +- Fix: Fastembed - Change default Sparse model as the used one is deprecated due to a typo (#1201) + ## [integrations/fastembed-v1.4.0] - 2024-11-13 -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Adopt uv as installer (#1142) +### 🌀 Miscellaneous + +- Chore: fastembed - pin `onnxruntime<1.20.0` (#1164) +- Feat: Fastembed - add FastembedRanker (#1178) + ## [integrations/fastembed-v1.3.0] - 2024-10-07 ### 🚀 Features @@ -16,20 +28,35 @@ - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) - Update ruff linting scripts and settings (#1105) -### Fix +### 🌀 Miscellaneous +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) - Typo on Sparse embedders. The parameter should be "progress_bar" … (#814) +- Chore: fastembed - ruff update, don't ruff tests (#997) ## [integrations/fastembed-v1.1.0] - 2024-05-15 +### 🌀 Miscellaneous + +- Chore: change the pydoc renderer class (#718) +- Use the local_files_only option available as of fastembed==0.2.7. 
It … (#736) + ## [integrations/fastembed-v1.0.0] - 2024-05-06 +### 🌀 Miscellaneous + +- Chore: add license classifiers (#680) +- `FastembedSparseTextEmbedder` - remove `batch_size` (#688) + ## [integrations/fastembed-v0.1.0] - 2024-04-10 ### 🚀 Features @@ -40,6 +67,10 @@ - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Remove references to Python 3.7 (#601) + ## [integrations/fastembed-v0.0.6] - 2024-03-07 ### 📚 Documentation @@ -47,20 +78,32 @@ - Review and normalize docstrings - `integrations.fastembed` (#519) - Small consistency improvements (#536) +### 🌀 Miscellaneous + +- Docs: Fix `integrations.fastembed` API docs (#540) +- Improvements to FastEmbed integration (#558) + ## [integrations/fastembed-v0.0.5] - 2024-02-20 ### 🐛 Bug Fixes - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) +### 🌀 Miscellaneous + +- Fastembed integration new parameters (#446) + ## [integrations/fastembed-v0.0.4] - 2024-02-16 +### 🌀 Miscellaneous + +- Fastembed integration: add example (#401) +- Fastembed fix: add parallel (#403) + ## [integrations/fastembed-v0.0.3] - 2024-02-12 ### 🐛 Bug Fixes @@ -73,6 +116,15 @@ This PR will also push the docs to Readme ## [integrations/fastembed-v0.0.2] - 2024-02-11 +### 🌀 Miscellaneous + +- Updated labeler and readme (#389) +- Fastembed fix: added prefix and suffix (#390) + ## [integrations/fastembed-v0.0.1] - 2024-02-10 +### 🌀 Miscellaneous + +- Add Fastembed Embeddings integration (#383) + From 472ada87f202d53c3107c37f6c9f7a7603e1b19e Mon Sep 17 00:00:00 2001 From: Trivan Menezes <47679108+ttmenezes@users.noreply.github.com> Date: Wed, 20 Nov 2024 06:08:41 -0800 Subject: [PATCH 090/229] feat: Add BM25 and Hybrid Search Retrievers to Azure AI Search Integration (#1175) * add new retrievers and tests --- .../retrievers/azure_ai_search/__init__.py | 4 +- .../azure_ai_search/bm25_retriever.py | 135 +++++++++++ .../azure_ai_search/embedding_retriever.py | 3 +- .../azure_ai_search/hybrid_retriever.py | 139 ++++++++++++ .../azure_ai_search/document_store.py | 93 +++++++- .../tests/test_bm25_retriever.py | 175 +++++++++++++++ .../tests/test_embedding_retriever.py | 60 +++++ .../tests/test_hybrid_retriever.py | 211 ++++++++++++++++++ 8 files changed, 811 insertions(+), 9 deletions(-) create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py create mode 100644 integrations/azure_ai_search/tests/test_bm25_retriever.py create mode 100644 integrations/azure_ai_search/tests/test_hybrid_retriever.py diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py index eb75ffa6c..56dc30db4 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -1,3 +1,5 @@ +from .bm25_retriever import AzureAISearchBM25Retriever from .embedding_retriever import AzureAISearchEmbeddingRetriever +from .hybrid_retriever import AzureAISearchHybridRetriever -__all__ = ["AzureAISearchEmbeddingRetriever"] +__all__ = ["AzureAISearchBM25Retriever", "AzureAISearchEmbeddingRetriever", 
"AzureAISearchHybridRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py new file mode 100644 index 000000000..4a1c7f98c --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -0,0 +1,135 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchBM25Retriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, + ): + """ + Create the AzureAISearchBM25Retriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the BM25 search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. + :raises RuntimeError: If the query is not valid, or if the document store is not correctly configured. + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._kwargs = kwargs + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise TypeError(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. 
+ """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :raises RuntimeError: If an error occurs during the BM25 retrieval process. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + filters = filters or self._filters + if filters: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = _normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._bm25_retrieval( + query=query, + filters=normalized_filters, + top_k=top_k, + **self._kwargs, + ) + except Exception as e: + msg = ( + "An error occurred during the bm25 retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query is valid and the document store is correctly configured." 
+ ) + raise RuntimeError(msg) from e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py index af48b74fb..69fad7208 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -107,7 +107,8 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = """ top_k = top_k or self._top_k - if filters is not None: + filters = filters or self._filters + if filters: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) normalized_filters = _normalize_filters(applied_filters) else: diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py new file mode 100644 index 000000000..79282933f --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -0,0 +1,139 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchHybridRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a hybrid (vector + BM25) retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, + ): + """ + Create the AzureAISearchHybridRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the hybrid search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. + :raises RuntimeError: If query or query_embedding are invalid, or if document store is not correctly configured. 
+ """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._kwargs = kwargs + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise TypeError(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + ): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param query_embedding: A list of floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more + details. + :param top_k: The maximum number of documents to retrieve. + :raises RuntimeError: If an error occurs during the hybrid retrieval process. + :returns: A dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + filters = filters or self._filters + if filters: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = _normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._hybrid_retrieval( + query=query, query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs + ) + except Exception as e: + msg = ( + "An error occurred during the hybrid retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query and query_embedding are valid and the document store is correctly configured." 
+ ) + raise RuntimeError(msg) from e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 74260b4fa..137ff621c 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -240,6 +240,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D Writes the provided documents to search index. :param documents: documents to write to the index. + :param policy: Policy to determine how duplicates are handled. + :raises ValueError: If the documents are not of type Document. + :raises TypeError: If the document ids are not strings. :return: the number of documents added to index. """ @@ -247,7 +250,7 @@ def _convert_input_document(documents: Document): document_dict = asdict(documents) if not isinstance(document_dict["id"], str): msg = f"Document id {document_dict['id']} is not a string, " - raise Exception(msg) + raise TypeError(msg) index_document = self._convert_haystack_documents_to_azure(document_dict) return index_document @@ -421,7 +424,7 @@ def _embedding_retrieval( ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. - It uses the vector configuration of the document store. By default it uses the HNSW algorithm + It uses the vector configuration specified in the document store. By default, it uses the HNSW algorithm with cosine similarity. This method is not meant to be part of the public interface of @@ -429,13 +432,12 @@ def _embedding_retrieval( `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. :param query_embedding: Embedding of the query. - :param top_k: Maximum number of Documents to return, defaults to 10. - :param filters: Filters applied to the retrieved Documents. Defaults to None. - Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return. + :param filters: Filters applied to the retrieved Documents. :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. - :raises ValueError: If `query_embedding` is an empty list - :returns: List of Document that are most similar to `query_embedding` + :raises ValueError: If `query_embedding` is an empty list. + :returns: List of Document that are most similar to `query_embedding`. """ if not query_embedding: @@ -446,3 +448,80 @@ def _embedding_retrieval( result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs) azure_docs = list(result) return self._convert_search_result_to_documents(azure_docs) + + def _bm25_retrieval( + self, + query: str, + top_k: int = 10, + filters: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> List[Document]: + """ + Retrieves documents that are most similar to `query`, using the BM25 algorithm. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchBM25Retriever` uses this method directly and is the public interface for it. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. 
+ :param top_k: Maximum number of Documents to return. + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. + + + :raises ValueError: If `query` is an empty string. + :returns: List of Document that are most similar to `query`. + """ + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + + result = self.client.search(search_text=query, filter=filters, top=top_k, query_type="simple", **kwargs) + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) + + def _hybrid_retrieval( + self, + query: str, + query_embedding: List[float], + top_k: int = 10, + filters: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> List[Document]: + """ + Retrieves documents similar to query using the vector configuration in the document store and + the BM25 algorithm. This method combines vector similarity and BM25 for improved retrieval. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchHybridRetriever` uses this method directly and is the public interface for it. + + :param query: Text of the query. + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. + + :raises ValueError: If `query` or `query_embedding` is empty. + :returns: List of Document that are most similar to `query`. + """ + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search( + search_text=query, + vector_queries=[vector_query], + filter=filters, + top=top_k, + query_type="simple", + **kwargs, + ) + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/tests/test_bm25_retriever.py b/integrations/azure_ai_search/tests/test_bm25_retriever.py new file mode 100644 index 000000000..6ebb20949 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_bm25_retriever.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import Mock + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchBM25Retriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = 
AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "metadata_fields": None, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchBM25Retriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + res = retriever.run(query="Test query") + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever( + document_store=mock_store, filters={"field": "type", "operator": "==", "value": "article"}, top_k=11 + ) + res = retriever.run(query="Test query") + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run(query="Test query", 
filters={"field": "type", "operator": "==", "value": "book"}, top_k=5) + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", filters="type eq 'book'", top_k=5, select="name" + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1", content="Test document")] + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.run(query="Test document") + assert res["documents"] == docs + + def test_document_retrieval(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document(content="This is first document"), + Document(content="This is second document"), + Document(content="This is third document"), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + results = retriever.run(query="This is first document") + assert results["documents"][0].content == "This is first document" diff --git a/integrations/azure_ai_search/tests/test_embedding_retriever.py b/integrations/azure_ai_search/tests/test_embedding_retriever.py index d4615ec44..576ecda08 100644 --- a/integrations/azure_ai_search/tests/test_embedding_retriever.py +++ b/integrations/azure_ai_search/tests/test_embedding_retriever.py @@ -103,6 +103,66 @@ def test_from_dict(): assert retriever._filter_policy == FilterPolicy.REPLACE +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + ) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run( + query_embedding=[0.5, 0.7], filters={"field": "type", "operator": "==", "value": "book"}, top_k=9 + ) + 
mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="type eq 'book'", + top_k=9, + select="name", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + @pytest.mark.skipif( not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", diff --git a/integrations/azure_ai_search/tests/test_hybrid_retriever.py b/integrations/azure_ai_search/tests/test_hybrid_retriever.py new file mode 100644 index 000000000..bf305c4fe --- /dev/null +++ b/integrations/azure_ai_search/tests/test_hybrid_retriever.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchHybridRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": 
"haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchHybridRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7], query="Test query") + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + ) + res = retriever.run(query_embedding=[0.5, 0.7], query="Test query") + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run( + query_embedding=[0.5, 0.7], + query="Test query", + filters={"field": "type", "operator": "==", "value": "book"}, + top_k=9, + ) + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="type eq 'book'", + top_k=9, + select="name", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.run(query="Test document", query_embedding=[0.1] * 768) + 
assert res["documents"] == docs + + def test_hybrid_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is third document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + results = retriever.run(query="This is first document", query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) From b42ec5c4c38a2fbcb6528c9631b4a633e92dab04 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 21 Nov 2024 10:06:22 +0100 Subject: [PATCH 091/229] feat: Pgvector - recreate the connection if it is no longer valid (#1202) * try refreshing connection * small improvements * rename method --- .../pgvector/document_store.py | 46 +++++++++++++++++-- .../pgvector/tests/test_document_store.py | 19 ++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 8e9c0f2fc..6682c2fee 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -156,29 +156,41 @@ def __init__( self._connection = None self._cursor = None self._dict_cursor = None + self._table_initialized = False @property def cursor(self): - if self._cursor is None: + if self._cursor is None or not self._connection_is_valid(self._connection): self._create_connection() return self._cursor @property def dict_cursor(self): - if self._dict_cursor is None: + if self._dict_cursor is None or not self._connection_is_valid(self._connection): self._create_connection() return self._dict_cursor @property def connection(self): - if self._connection is None: + if self._connection is None or not self._connection_is_valid(self._connection): self._create_connection() return self._connection def _create_connection(self): + """ + Internal method to create a connection to the PostgreSQL database. 
+ """ + + # close the connection if it already exists + if self._connection: + try: + self._connection.close() + except Error as e: + logger.debug("Failed to close connection: %s", str(e)) + conn_str = self.connection_string.resolve_value() or "" connection = connect(conn_str) connection.autocommit = True @@ -189,16 +201,40 @@ def _create_connection(self): self._cursor = self._connection.cursor() self._dict_cursor = self._connection.cursor(row_factory=dict_row) - # Init schema + if not self._table_initialized: + self._initialize_table() + + return self._connection + + def _initialize_table(self): + """ + Internal method to initialize the table. + """ if self.recreate_table: self.delete_table() + self._create_table_if_not_exists() self._create_keyword_index_if_not_exists() if self.search_strategy == "hnsw": self._handle_hnsw() - return self._connection + self._table_initialized = True + + @staticmethod + def _connection_is_valid(connection): + """ + Internal method to check if the connection is still valid. + """ + + # implementation inspired to psycopg pool + # https://github.com/psycopg/psycopg/blob/d38cf7798b0c602ff43dac9f20bbab96237a9c38/psycopg_pool/psycopg_pool/pool.py#L528 + + try: + connection.execute("") + except Error: + return False + return True def to_dict(self) -> Dict[str, Any]: """ diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 4af4fc8de..c6f160f91 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -41,6 +41,25 @@ def test_write_dataframe(self, document_store: PgvectorDocumentStore): retrieved_docs = document_store.filter_documents() assert retrieved_docs == docs + def test_connection_check_and_recreation(self, document_store: PgvectorDocumentStore): + original_connection = document_store.connection + + with patch.object(PgvectorDocumentStore, "_connection_is_valid", return_value=False): + new_connection = document_store.connection + + # verify that a new connection is created + assert new_connection is not original_connection + assert document_store._connection == new_connection + assert original_connection.closed + + assert document_store._cursor is not None + assert document_store._dict_cursor is not None + + # test with new connection + with patch.object(PgvectorDocumentStore, "_connection_is_valid", return_value=True): + same_connection = document_store.connection + assert same_connection is document_store._connection + @pytest.mark.usefixtures("patches_for_unit_tests") def test_init(monkeypatch): From 96e3951003a361e94259d269a7fa51a5140f726e Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 21 Nov 2024 09:09:05 +0000 Subject: [PATCH 092/229] Update the changelog --- integrations/pgvector/CHANGELOG.md | 50 +++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/integrations/pgvector/CHANGELOG.md b/integrations/pgvector/CHANGELOG.md index 0fe5f4fa4..7c8be2340 100644 --- a/integrations/pgvector/CHANGELOG.md +++ b/integrations/pgvector/CHANGELOG.md @@ -1,24 +1,43 @@ # Changelog -## [integrations/pgvector-v1.0.0] - 2024-09-12 +## [integrations/pgvector-v1.1.0] - 2024-11-21 ### 🚀 Features - Add filter_policy to pgvector integration (#820) +- Add schema support to pgvector document store. 
(#1095) +- Pgvector - recreate the connection if it is no longer valid (#1202) ### 🐛 Bug Fixes - `PgVector` - Fallback to default filter policy when deserializing retrievers without the init parameter (#900) +### 📚 Documentation + +- Explain different connection string formats in the docstring (#1132) + ### 🧪 Testing - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) +- Adopt uv as installer (#1142) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) - PgVector - remove legacy filter support (#1068) +- Update changelog after removing legacy filters (#1083) +- Update ruff linting scripts and settings (#1105) + +### 🌀 Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Chore: Minor retriever pydoc fix (#884) +- Chore: Update pgvector test for the new `apply_filter_policy` usage (#970) +- Chore: pgvector ruff update, don't ruff tests (#984) ## [integrations/pgvector-v0.4.0] - 2024-06-20 @@ -27,6 +46,11 @@ - Defer the database connection to when it's needed (#773) - Add customizable index names for pgvector (#818) +### 🌀 Miscellaneous + +- Docs: add missing api references (#728) +- [deepset-ai/haystack-core-integrations#727] (#738) + ## [integrations/pgvector-v0.2.0] - 2024-05-08 ### 🚀 Features @@ -38,19 +62,35 @@ - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Pgvector - review docstrings and API reference (#502) +- Refactor tests (#574) +- Remove references to Python 3.7 (#601) +- Make Document Stores initially skip `SparseEmbedding` (#606) +- Chore: add license classifiers (#680) +- Type hints in pgvector document store updated for 3.8 compability (#704) +- Chore: change the pydoc renderer class (#718) + ## [integrations/pgvector-v0.1.0] - 2024-02-14 ### 🐛 Bug Fixes -- Fix linting (#328) +- Pgvector: fix linting (#328) +### 🌀 Miscellaneous +- Pgvector Document Store - minimal implementation (#239) +- Pgvector - filters (#257) +- Pgvector - embedding retrieval (#298) +- Pgvector - Embedding Retriever (#320) +- Pgvector: generate API docs (#325) +- Pgvector: add an example (#334) +- Adopt `Secret` to pgvector (#402) From 16bc80f0d49e3c50b90c17e55b7fb7118a4b8485 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 21 Nov 2024 10:33:39 +0100 Subject: [PATCH 093/229] feat: Improvements to NvidiaRanker and adding user input timeout (#1193) * Lots of fixes * Remove unused import * Fix readme * linting * Add more logging * Follow same private/public attribute as other components * Add tests * Linting * Add another test * Add timeout to to_dict * Update integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py Co-authored-by: David S. Batista --------- Co-authored-by: David S. 
Batista --- integrations/nvidia/README.md | 2 +- .../embedders/nvidia/document_embedder.py | 22 ++++- .../embedders/nvidia/text_embedder.py | 29 +++++- .../components/generators/nvidia/generator.py | 11 ++- .../components/rankers/nvidia/ranker.py | 99 ++++++++++++++----- .../utils/nvidia/nim_backend.py | 36 ++++--- .../nvidia/tests/test_document_embedder.py | 25 ++++- integrations/nvidia/tests/test_generator.py | 13 +++ integrations/nvidia/tests/test_ranker.py | 90 +++++++++++++---- .../nvidia/tests/test_text_embedder.py | 18 ++++ 10 files changed, 279 insertions(+), 66 deletions(-) diff --git a/integrations/nvidia/README.md b/integrations/nvidia/README.md index e28f0ede9..558c34d28 100644 --- a/integrations/nvidia/README.md +++ b/integrations/nvidia/README.md @@ -38,7 +38,7 @@ hatch run test To only run unit tests: ``` -hatch run test -m"not integration" +hatch run test -m "not integration" ``` To run the linters `ruff` and `mypy`: diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index 606ec78fd..6519efbab 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -2,16 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 +import os import warnings from typing import Any, Dict, List, Optional, Tuple, Union -from haystack import Document, component, default_from_dict, default_to_dict +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace from tqdm import tqdm from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation +logger = logging.getLogger(__name__) + _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @@ -47,6 +50,7 @@ def __init__( meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", truncate: Optional[Union[EmbeddingTruncateMode, str]] = None, + timeout: Optional[float] = None, ): """ Create a NvidiaTextEmbedder component. @@ -74,8 +78,11 @@ def __init__( :param embedding_separator: Separator used to concatenate the meta fields to the Document text. :param truncate: - Specifies how inputs longer that the maximum token length should be truncated. + Specifies how inputs longer than the maximum token length should be truncated. If None the behavior is model-dependent, see the official documentation for more information. + :param timeout: + Timeout for request calls, if not set it is inferred from the `NVIDIA_TIMEOUT` environment variable + or set to 60 by default. 
""" self.api_key = api_key @@ -98,6 +105,10 @@ def __init__( if is_hosted(api_url) and not self.model: # manually set default model self.model = "nvidia/nv-embedqa-e5-v5" + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0)) + self.timeout = timeout + def default_model(self): """Set default model in local NIM mode.""" valid_models = [ @@ -128,10 +139,11 @@ def warm_up(self): if self.truncate is not None: model_kwargs["truncate"] = str(self.truncate) self.backend = NimBackend( - self.model, + model=self.model, api_url=self.api_url, api_key=self.api_key, model_kwargs=model_kwargs, + timeout=self.timeout, ) self._initialized = True @@ -158,6 +170,7 @@ def to_dict(self) -> Dict[str, Any]: meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, truncate=str(self.truncate) if self.truncate is not None else None, + timeout=self.timeout, ) @classmethod @@ -238,8 +251,7 @@ def run(self, documents: List[Document]): for doc in documents: if not doc.content: - msg = f"Document '{doc.id}' has no content to embed." - raise ValueError(msg) + logger.warning(f"Document '{doc.id}' has no content to embed.") texts_to_embed = self._prepare_texts_to_embed(documents) embeddings, metadata = self._embed_batch(texts_to_embed, self.batch_size) diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 4b7072f33..a93aa8caa 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -2,15 +2,18 @@ # # SPDX-License-Identifier: Apache-2.0 +import os import warnings from typing import Any, Dict, List, Optional, Union -from haystack import component, default_from_dict, default_to_dict +from haystack import component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation +logger = logging.getLogger(__name__) + _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @@ -44,6 +47,7 @@ def __init__( prefix: str = "", suffix: str = "", truncate: Optional[Union[EmbeddingTruncateMode, str]] = None, + timeout: Optional[float] = None, ): """ Create a NvidiaTextEmbedder component. @@ -64,6 +68,9 @@ def __init__( :param truncate: Specifies how inputs longer that the maximum token length should be truncated. If None the behavior is model-dependent, see the official documentation for more information. + :param timeout: + Timeout for request calls, if not set it is inferred from the `NVIDIA_TIMEOUT` environment variable + or set to 60 by default. """ self.api_key = api_key @@ -82,6 +89,10 @@ def __init__( if is_hosted(api_url) and not self.model: # manually set default model self.model = "nvidia/nv-embedqa-e5-v5" + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0)) + self.timeout = timeout + def default_model(self): """Set default model in local NIM mode.""" valid_models = [ @@ -89,6 +100,12 @@ def default_model(self): ] name = next(iter(valid_models), None) if name: + logger.warning( + "Default model is set as: {model_name}. \n" + "Set model using model parameter. 
\n" + "To get available models use available_models property.", + model_name=name, + ) warnings.warn( f"Default model is set as: {name}. \n" "Set model using model parameter. \n" @@ -112,10 +129,11 @@ def warm_up(self): if self.truncate is not None: model_kwargs["truncate"] = str(self.truncate) self.backend = NimBackend( - self.model, + model=self.model, api_url=self.api_url, api_key=self.api_key, model_kwargs=model_kwargs, + timeout=self.timeout, ) self._initialized = True @@ -138,6 +156,7 @@ def to_dict(self) -> Dict[str, Any]: prefix=self.prefix, suffix=self.suffix, truncate=str(self.truncate) if self.truncate is not None else None, + timeout=self.timeout, ) @classmethod @@ -150,7 +169,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaTextEmbedder": :returns: The deserialized component. """ - deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) @component.output_types(embedding=List[float], meta=Dict[str, Any]) @@ -162,7 +183,7 @@ def run(self, text: str): The text to embed. :returns: A dictionary with the following keys and values: - - `embedding` - Embeddng of the text. + - `embedding` - Embedding of the text. - `meta` - Metadata on usage statistics, etc. :raises RuntimeError: If the component was not initialized. diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py index 5bf71a9e1..5047d0682 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import os import warnings from typing import Any, Dict, List, Optional @@ -49,6 +50,7 @@ def __init__( api_url: str = _DEFAULT_API_URL, api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), model_arguments: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, ): """ Create a NvidiaGenerator component. @@ -70,6 +72,9 @@ def __init__( specific to a model. Search your model in the [NVIDIA NIM](https://ai.nvidia.com) to find the arguments it accepts. + :param timeout: + Timeout for request calls, if not set it is inferred from the `NVIDIA_TIMEOUT` environment variable + or set to 60 by default. """ self._model = model self._api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/chat/completions"]) @@ -79,6 +84,9 @@ def __init__( self._backend: Optional[Any] = None self.is_hosted = is_hosted(api_url) + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0)) + self.timeout = timeout def default_model(self): """Set default model in local NIM mode.""" @@ -110,10 +118,11 @@ def warm_up(self): msg = "API key is required for hosted NVIDIA NIMs." 
raise ValueError(msg) self._backend = NimBackend( - self._model, + model=self._model, api_url=self._api_url, api_key=self._api_key, model_kwargs=self._model_arguments, + timeout=self.timeout, ) if not self.is_hosted and not self._model: diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index 9938b37d1..66203a490 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import os import warnings from typing import Any, Dict, List, Optional, Union @@ -58,6 +59,11 @@ def __init__( api_url: Optional[str] = None, api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), top_k: int = 5, + query_prefix: str = "", + document_prefix: str = "", + meta_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + timeout: Optional[float] = None, ): """ Create a NvidiaRanker component. @@ -72,6 +78,19 @@ def __init__( Custom API URL for the NVIDIA NIM. :param top_k: Number of documents to return. + :param query_prefix: + A string to add at the beginning of the query text before ranking. + Use it to prepend the text with an instruction, as required by reranking models like `bge`. + :param document_prefix: + A string to add at the beginning of each document before ranking. You can use it to prepend the document + with an instruction, as required by embedding models like `bge`. + :param meta_fields_to_embed: + List of metadata fields to embed with the document. + :param embedding_separator: + Separator to concatenate metadata fields to the document. + :param timeout: + Timeout for request calls, if not set it is inferred from the `NVIDIA_TIMEOUT` environment variable + or set to 60 by default. """ if model is not None and not isinstance(model, str): msg = "Ranker expects the `model` parameter to be a string." @@ -86,27 +105,35 @@ def __init__( raise TypeError(msg) # todo: detect default in non-hosted case (when api_url is provided) - self._model = model or _DEFAULT_MODEL - self._truncate = truncate - self._api_key = api_key + self.model = model or _DEFAULT_MODEL + self.truncate = truncate + self.api_key = api_key # if no api_url is provided, we're using a hosted model and can # - assume the default url will work, because there's only one model # - assume we won't call backend.models() if api_url is not None: - self._api_url = url_validation(api_url, None, ["v1/ranking"]) - self._endpoint = None # we let backend.rank() handle the endpoint + self.api_url = url_validation(api_url, None, ["v1/ranking"]) + self.endpoint = None # we let backend.rank() handle the endpoint else: - if self._model not in _MODEL_ENDPOINT_MAP: + if self.model not in _MODEL_ENDPOINT_MAP: msg = f"Model '{model}' is unknown. Please provide an api_url to access it." 
raise ValueError(msg) - self._api_url = None # we handle the endpoint - self._endpoint = _MODEL_ENDPOINT_MAP[self._model] + self.api_url = None # we handle the endpoint + self.endpoint = _MODEL_ENDPOINT_MAP[self.model] if api_key is None: self._api_key = Secret.from_env_var("NVIDIA_API_KEY") - self._top_k = top_k + self.top_k = top_k self._initialized = False self._backend: Optional[Any] = None + self.query_prefix = query_prefix + self.document_prefix = document_prefix + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0)) + self.timeout = timeout + def to_dict(self) -> Dict[str, Any]: """ Serialize the ranker to a dictionary. @@ -115,11 +142,16 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - model=self._model, - top_k=self._top_k, - truncate=self._truncate, - api_url=self._api_url, - api_key=self._api_key.to_dict() if self._api_key else None, + model=self.model, + top_k=self.top_k, + truncate=self.truncate, + api_url=self.api_url, + api_key=self.api_key.to_dict() if self.api_key else None, + query_prefix=self.query_prefix, + document_prefix=self.document_prefix, + meta_fields_to_embed=self.meta_fields_to_embed, + embedding_separator=self.embedding_separator, + timeout=self.timeout, ) @classmethod @@ -143,18 +175,31 @@ def warm_up(self): """ if not self._initialized: model_kwargs = {} - if self._truncate is not None: - model_kwargs.update(truncate=str(self._truncate)) + if self.truncate is not None: + model_kwargs.update(truncate=str(self.truncate)) self._backend = NimBackend( - self._model, - api_url=self._api_url, - api_key=self._api_key, + model=self.model, + api_url=self.api_url, + api_key=self.api_key, model_kwargs=model_kwargs, + timeout=self.timeout, ) - if not self._model: - self._model = _DEFAULT_MODEL + if not self.model: + self.model = _DEFAULT_MODEL self._initialized = True + def _prepare_documents_to_embed(self, documents: List[Document]) -> List[str]: + document_texts = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) + for key in self.meta_fields_to_embed + if key in doc.meta and doc.meta[key] # noqa: RUF019 + ] + text_to_embed = self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + document_texts.append(self.document_prefix + text_to_embed) + return document_texts + @component.output_types(documents=List[Document]) def run( self, @@ -193,18 +238,22 @@ def run( if len(documents) == 0: return {"documents": []} - top_k = top_k if top_k is not None else self._top_k + top_k = top_k if top_k is not None else self.top_k if top_k < 1: logger.warning("top_k should be at least 1, returning nothing") warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2) return {"documents": []} assert self._backend is not None + + query_text = self.query_prefix + query + document_texts = self._prepare_documents_to_embed(documents=documents) + # rank result is list[{index: int, logit: float}] sorted by logit sorted_indexes_and_scores = self._backend.rank( - query, - documents, - endpoint=self._endpoint, + query_text=query_text, + document_texts=document_texts, + endpoint=self.endpoint, ) sorted_documents = [] for item in sorted_indexes_and_scores[:top_k]: diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index 0279cf608..15b35e4b2 100644 --- 
a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -2,14 +2,17 @@ # # SPDX-License-Identifier: Apache-2.0 +import os from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple import requests -from haystack import Document +from haystack import logging from haystack.utils import Secret -REQUEST_TIMEOUT = 60 +logger = logging.getLogger(__name__) + +REQUEST_TIMEOUT = 60.0 @dataclass @@ -35,6 +38,7 @@ def __init__( api_url: str, api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), model_kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, ): headers = { "Content-Type": "application/json", @@ -50,6 +54,9 @@ def __init__( self.model = model self.api_url = api_url self.model_kwargs = model_kwargs or {} + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", REQUEST_TIMEOUT)) + self.timeout = timeout def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: url = f"{self.api_url}/embeddings" @@ -62,10 +69,11 @@ def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: "input": texts, **self.model_kwargs, }, - timeout=REQUEST_TIMEOUT, + timeout=self.timeout, ) res.raise_for_status() except requests.HTTPError as e: + logger.error("Error when calling NIM embedding endpoint: Error - {error}", error=e.response.text) msg = f"Failed to query embedding endpoint: Error - {e.response.text}" raise ValueError(msg) from e @@ -94,10 +102,11 @@ def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: ], **self.model_kwargs, }, - timeout=REQUEST_TIMEOUT, + timeout=self.timeout, ) res.raise_for_status() except requests.HTTPError as e: + logger.error("Error when calling NIM chat completion endpoint: Error - {error}", error=e.response.text) msg = f"Failed to query chat completion endpoint: Error - {e.response.text}" raise ValueError(msg) from e @@ -132,21 +141,22 @@ def models(self) -> List[Model]: res = self.session.get( url, - timeout=REQUEST_TIMEOUT, + timeout=self.timeout, ) res.raise_for_status() data = res.json()["data"] models = [Model(element["id"]) for element in data if "id" in element] if not models: + logger.error("No hosted model were found at URL '{u}'.", u=url) msg = f"No hosted model were found at URL '{url}'." 
raise ValueError(msg) return models def rank( self, - query: str, - documents: List[Document], + query_text: str, + document_texts: List[str], endpoint: Optional[str] = None, ) -> List[Dict[str, Any]]: url = endpoint or f"{self.api_url}/ranking" @@ -156,18 +166,22 @@ def rank( url, json={ "model": self.model, - "query": {"text": query}, - "passages": [{"text": doc.content} for doc in documents], + "query": {"text": query_text}, + "passages": [{"text": text} for text in document_texts], **self.model_kwargs, }, - timeout=REQUEST_TIMEOUT, + timeout=self.timeout, ) res.raise_for_status() except requests.HTTPError as e: + logger.error("Error when calling NIM ranking endpoint: Error - {error}", error=e.response.text) msg = f"Failed to rank endpoint: Error - {e.response.text}" raise ValueError(msg) from e data = res.json() - assert "rankings" in data, f"Expected 'rankings' in response, got {data}" + if "rankings" not in data: + logger.error("Expected 'rankings' in response, got {d}", d=data) + msg = f"Expected 'rankings' in response, got {data}" + raise ValueError(msg) return data["rankings"] diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index 7e0e02f3d..8c01f0759 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -75,6 +75,7 @@ def test_to_dict(self, monkeypatch): "meta_fields_to_embed": [], "embedding_separator": "\n", "truncate": None, + "timeout": 60.0, }, } @@ -90,6 +91,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): meta_fields_to_embed=["test_field"], embedding_separator=" | ", truncate=EmbeddingTruncateMode.END, + timeout=45.0, ) data = component.to_dict() assert data == { @@ -105,6 +107,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", "truncate": "END", + "timeout": 45.0, }, } @@ -123,6 +126,7 @@ def test_from_dict(self, monkeypatch): "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", "truncate": "START", + "timeout": 45.0, }, } component = NvidiaDocumentEmbedder.from_dict(data) @@ -135,6 +139,7 @@ def test_from_dict(self, monkeypatch): assert component.meta_fields_to_embed == ["test_field"] assert component.embedding_separator == " | " assert component.truncate == EmbeddingTruncateMode.START + assert component.timeout == 45.0 def test_from_dict_defaults(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") @@ -152,6 +157,7 @@ def test_from_dict_defaults(self, monkeypatch): assert component.meta_fields_to_embed == [] assert component.embedding_separator == "\n" assert component.truncate is None + assert component.timeout == 60.0 def test_prepare_texts_to_embed_w_metadata(self): documents = [ @@ -347,7 +353,7 @@ def test_run_wrong_input_format(self): with pytest.raises(TypeError, match="NvidiaDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=list_integers_input) - def test_run_empty_document(self): + def test_run_empty_document(self, caplog): model = "playground_nvolveqa_40k" api_key = Secret.from_token("fake-api-key") embedder = NvidiaDocumentEmbedder(model, api_key=api_key) @@ -355,8 +361,10 @@ def test_run_empty_document(self): embedder.warm_up() embedder.backend = MockBackend(model=model, api_key=api_key) - with pytest.raises(ValueError, match="no content to embed"): + # Write check using caplog that a logger.warning is raised + with caplog.at_level("WARNING"): 
embedder.run(documents=[Document(content="")]) + assert "has no content to embed." in caplog.text def test_run_on_empty_list(self): model = "playground_nvolveqa_40k" @@ -372,6 +380,19 @@ def test_run_on_empty_list(self): assert result["documents"] is not None assert not result["documents"] # empty list + def test_setting_timeout(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + embedder = NvidiaDocumentEmbedder(timeout=10.0) + embedder.warm_up() + assert embedder.backend.timeout == 10.0 + + def test_setting_timeout_env(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + embedder = NvidiaDocumentEmbedder() + embedder.warm_up() + assert embedder.backend.timeout == 45.0 + @pytest.mark.skipif( not os.environ.get("NVIDIA_API_KEY", None), reason="Export an env var called NVIDIA_API_KEY containing the Nvidia API key to run this test.", diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 055830ae5..414de4884 100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -124,6 +124,19 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } + def test_setting_timeout(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaGenerator(timeout=10.0) + generator.warm_up() + assert generator._backend.timeout == 10.0 + + def test_setting_timeout_env(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + generator = NvidiaGenerator() + generator.warm_up() + assert generator._backend.timeout == 45.0 + @pytest.mark.skipif( not os.environ.get("NVIDIA_NIM_GENERATOR_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_GENERATOR_MODEL containing the hosted model name and " diff --git a/integrations/nvidia/tests/test_ranker.py b/integrations/nvidia/tests/test_ranker.py index d66bb0f65..3d93dc028 100644 --- a/integrations/nvidia/tests/test_ranker.py +++ b/integrations/nvidia/tests/test_ranker.py @@ -19,8 +19,8 @@ class TestNvidiaRanker: def test_init_default(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") client = NvidiaRanker() - assert client._model == _DEFAULT_MODEL - assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert client.model == _DEFAULT_MODEL + assert client.api_key == Secret.from_env_var("NVIDIA_API_KEY") def test_init_with_parameters(self): client = NvidiaRanker( @@ -29,10 +29,10 @@ def test_init_with_parameters(self): top_k=3, truncate="END", ) - assert client._api_key == Secret.from_token("fake-api-key") - assert client._model == _DEFAULT_MODEL - assert client._top_k == 3 - assert client._truncate == RankerTruncateMode.END + assert client.api_key == Secret.from_token("fake-api-key") + assert client.model == _DEFAULT_MODEL + assert client.top_k == 3 + assert client.truncate == RankerTruncateMode.END def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("NVIDIA_API_KEY", raising=False) @@ -43,7 +43,7 @@ def test_init_fail_wo_api_key(self, monkeypatch): def test_init_pass_wo_api_key_w_api_url(self): url = "https://url.bogus/v1" client = NvidiaRanker(api_url=url) - assert client._api_url == url + assert client.api_url == url def test_warm_up_required(self): client = NvidiaRanker() @@ -271,6 +271,11 @@ def test_to_dict(self) -> None: "truncate": None, "api_url": None, 
"api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + "query_prefix": "", + "document_prefix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "timeout": 60.0, }, } @@ -284,14 +289,24 @@ def test_from_dict(self) -> None: "truncate": None, "api_url": None, "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + "query_prefix": "", + "document_prefix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "timeout": 45.0, }, } ) - assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" - assert client._top_k == 5 - assert client._truncate is None - assert client._api_url is None - assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert client.model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client.top_k == 5 + assert client.truncate is None + assert client.api_url is None + assert client.api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert client.query_prefix == "" + assert client.document_prefix == "" + assert client.meta_fields_to_embed == [] + assert client.embedding_separator == "\n" + assert client.timeout == 45.0 def test_from_dict_defaults(self) -> None: client = NvidiaRanker.from_dict( @@ -300,8 +315,49 @@ def test_from_dict_defaults(self) -> None: "init_parameters": {}, } ) - assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" - assert client._top_k == 5 - assert client._truncate is None - assert client._api_url is None - assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert client.model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client.top_k == 5 + assert client.truncate is None + assert client.api_url is None + assert client.api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert client.query_prefix == "" + assert client.document_prefix == "" + assert client.meta_fields_to_embed == [] + assert client.embedding_separator == "\n" + assert client.timeout == 60.0 + + def test_setting_timeout(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + client = NvidiaRanker(timeout=10.0) + client.warm_up() + assert client._backend.timeout == 10.0 + + def test_setting_timeout_env(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + client = NvidiaRanker() + client.warm_up() + assert client._backend.timeout == 45.0 + + def test_prepare_texts_to_embed_w_metadata(self): + documents = [ + Document(content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5) + ] + + ranker = NvidiaRanker( + model=None, + api_key=Secret.from_token("fake-api-key"), + meta_fields_to_embed=["meta_field"], + embedding_separator=" | ", + ) + + prepared_texts = ranker._prepare_documents_to_embed(documents) + + # note that newline is replaced by space + assert prepared_texts == [ + "meta_value 0 | document number 0:\ncontent", + "meta_value 1 | document number 1:\ncontent", + "meta_value 2 | document number 2:\ncontent", + "meta_value 3 | document number 3:\ncontent", + "meta_value 4 | document number 4:\ncontent", + ] diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 278fa5191..b572cc046 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -56,6 +56,7 @@ def test_to_dict(self, monkeypatch): "prefix": "", "suffix": "", "truncate": None, + "timeout": 60.0, }, } @@ -67,6 +68,7 @@ def 
test_to_dict_with_custom_init_parameters(self, monkeypatch): prefix="prefix", suffix="suffix", truncate=EmbeddingTruncateMode.START, + timeout=10.0, ) data = component.to_dict() assert data == { @@ -78,6 +80,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "prefix": "prefix", "suffix": "suffix", "truncate": "START", + "timeout": 10.0, }, } @@ -92,6 +95,7 @@ def test_from_dict(self, monkeypatch): "prefix": "prefix", "suffix": "suffix", "truncate": "START", + "timeout": 10.0, }, } component = NvidiaTextEmbedder.from_dict(data) @@ -100,6 +104,7 @@ def test_from_dict(self, monkeypatch): assert component.prefix == "prefix" assert component.suffix == "suffix" assert component.truncate == EmbeddingTruncateMode.START + assert component.timeout == 10.0 def test_from_dict_defaults(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") @@ -175,6 +180,19 @@ def test_run_empty_string(self): with pytest.raises(ValueError, match="empty string"): embedder.run(text="") + def test_setting_timeout(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + embedder = NvidiaTextEmbedder(timeout=10.0) + embedder.warm_up() + assert embedder.backend.timeout == 10.0 + + def test_setting_timeout_env(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + embedder = NvidiaTextEmbedder() + embedder.warm_up() + assert embedder.backend.timeout == 45.0 + @pytest.mark.skipif( not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " From 028e2c387e228aea8549e0c585d2315cb945f3a6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 21 Nov 2024 13:19:41 +0100 Subject: [PATCH 094/229] =?UTF-8?q?fix:=20AmazonBedrockChatGenerator=20wit?= =?UTF-8?q?h=20Claude=20raises=20moot=20warning=20for=20stream=E2=80=A6=20?= =?UTF-8?q?(#1205)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * AmazonBedrockChatGenerator with Claude raises moot warning for stream kwarg * Retire meta.llama2-13b-chat-v1 from tests * AmazonBedrockChatGenerator with Mistral raises moot warning for stream kwarg --- .../components/generators/amazon_bedrock/chat/adapters.py | 6 ++++++ integrations/amazon_bedrock/tests/test_chat_generator.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index f5e8f8181..cbb5ee370 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -212,6 +212,8 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences + # pop stream kwarg from inference_kwargs as Anthropic does not support it (if provided) + inference_kwargs.pop("stream", None) params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) body = {**self.prepare_chat_messages(messages=messages), **params} return body @@ -384,6 +386,10 
@@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ stop_words = inference_kwargs.pop("stop_words", []) if stop_words: inference_kwargs["stop"] = stop_words + + # pop stream kwarg from inference_kwargs as Mistral does not support it (if provided) + inference_kwargs.pop("stream", None) + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) body = {"prompt": self.prepare_chat_messages(messages=messages), **params} return body diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 571e03eb2..185a34c8a 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -17,7 +17,7 @@ ) KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] +MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"] MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"] MISTRAL_MODELS = [ "mistral.mistral-7b-instruct-v0:2", From c28b83451bd70369b5503e7349023e4d8b314809 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 21 Nov 2024 15:22:22 +0000 Subject: [PATCH 095/229] Update the changelog --- integrations/anthropic/CHANGELOG.md | 40 +++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/integrations/anthropic/CHANGELOG.md b/integrations/anthropic/CHANGELOG.md index a219da6a5..a7cdc7d09 100644 --- a/integrations/anthropic/CHANGELOG.md +++ b/integrations/anthropic/CHANGELOG.md @@ -1,11 +1,29 @@ # Changelog +## [unreleased] + +### ⚙️ CI + +- Adopt uv as installer (#1142) + +### 🧹 Chores + +- Update ruff linting scripts and settings (#1105) + +### 🌀 Miscellaneous + +- Add AnthropicVertexChatGenerator component (#1192) + ## [integrations/anthropic-v1.1.0] - 2024-09-20 ### 🚀 Features - Add Anthropic prompt caching support, add example (#1006) +### 🌀 Miscellaneous + +- Chore: Update Anthropic example, use ChatPromptBuilder properly (#978) + ## [integrations/anthropic-v1.0.0] - 2024-08-12 ### 🐛 Bug Fixes @@ -20,12 +38,18 @@ - Do not retry tests in `hatch run test` command (#954) + ## [integrations/anthropic-v0.4.1] - 2024-07-17 -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Update ruff invocation to include check parameter (#853) +### 🌀 Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Add meta deprecration warning (#910) + ## [integrations/anthropic-v0.4.0] - 2024-06-21 ### 🚀 Features @@ -33,12 +57,24 @@ - Update Anthropic/Cohere for tools use (#790) - Update Anthropic default models, pydocs (#839) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) +### 🌀 Miscellaneous + +- Remove references to Python 3.7 (#601) +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Docs: add missing api references (#728) + ## [integrations/anthropic-v0.2.0] - 2024-03-15 +### 🌀 Miscellaneous + +- Docs: Replace amazon-bedrock with anthropic in readme (#584) +- Chore: Use the correct sonnet model name (#587) + ## [integrations/anthropic-v0.1.0] - 2024-03-15 ### 🚀 Features From cb107c375f7a80d34981fe52de8a5ac01ca15a70 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 21 Nov 2024 16:24:25 +0100 Subject: [PATCH 096/229] Fix generated tag name for 
version release(#1206) --- integrations/azure_ai_search/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/azure_ai_search/pyproject.toml b/integrations/azure_ai_search/pyproject.toml index 49ca623e7..cb967b1e0 100644 --- a/integrations/azure_ai_search/pyproject.toml +++ b/integrations/azure_ai_search/pyproject.toml @@ -33,11 +33,11 @@ packages = ["src/haystack_integrations"] [tool.hatch.version] source = "vcs" -tag-pattern = 'integrations\/azure-ai-search-v(?P.*)' +tag-pattern = 'integrations\/azure_ai_search-v(?P.*)' [tool.hatch.version.raw-options] root = "../.." -git_describe_command = 'git describe --tags --match="integrations/azure-ai-search-v[0-9]*"' +git_describe_command = 'git describe --tags --match="integrations/azure_ai_search-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ From 872dd84aeb054c7deaa031152a636ec5ce7bcd5c Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 21 Nov 2024 15:28:23 +0000 Subject: [PATCH 097/229] Update the changelog --- integrations/azure_ai_search/CHANGELOG.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 integrations/azure_ai_search/CHANGELOG.md diff --git a/integrations/azure_ai_search/CHANGELOG.md b/integrations/azure_ai_search/CHANGELOG.md new file mode 100644 index 000000000..e22559b12 --- /dev/null +++ b/integrations/azure_ai_search/CHANGELOG.md @@ -0,0 +1,15 @@ +# Changelog + +## [integrations/azure_ai_search-v0.1.0] - 2024-11-21 + +### 🚀 Features + +- Add Azure AI Search integration (#1122) +- Add BM25 and Hybrid Search Retrievers to Azure AI Search Integration (#1175) + +### 🌀 Miscellaneous + +- Enable kwargs in SearchIndex and Embedding Retriever (#1185) +- Fix: Fix tag name for version release (#1206) + + From f9d0e77ae9f0e850d3e3ee5038ecfd0b77f98222 Mon Sep 17 00:00:00 2001 From: jlonge4 <91354480+jlonge4@users.noreply.github.com> Date: Thu, 21 Nov 2024 12:00:20 -0500 Subject: [PATCH 098/229] feat: add `JinaReaderConnector` (#1150) * begin rough draft * begin rough draft * begin rough draft * small fixes * Haystack document conversion * git folder changes * add pipeline functions * correct mode map * add reader mode Enum class file * add docstrings * add JINA url for ref * add mode norm for mode map check in run method * add mode norm for mode map check in run method * add json_response and associated parsing * ignore api key lint error * ignore api key lint error * reduce code redundancy * reduce code redundancy * add headers option to run method * Update integrations/jina/src/haystack_integrations/components/reader/jina/reader.py Co-authored-by: Stefano Fiorucci * Update integrations/jina/src/haystack_integrations/components/reader/jina/reader.py Co-authored-by: Stefano Fiorucci * Update integrations/jina/src/haystack_integrations/components/reader/jina/reader.py Co-authored-by: Stefano Fiorucci * Update integrations/jina/src/haystack_integrations/components/reader/jina/reader.py Co-authored-by: Stefano Fiorucci * update location / final edits * Update integrations/jina/src/haystack_integrations/components/converters/jina/reader.py Co-authored-by: Stefano Fiorucci * update paths * add descriptions for json response/headers * lint * unit tests for reader-connector * unit tests for reader-connector * unit tests for reader-connector * fix circular import * update header test * update test * update test * update test * update test * update test * update test * update test * update test * refactoring + more tests * example * pydoc config * examples can contain 
print --------- Co-authored-by: anitha6g Co-authored-by: Stefano Fiorucci --- .../jina/examples/jina_reader_connector.py | 47 ++++++ integrations/jina/pydoc/config.yml | 1 + integrations/jina/pyproject.toml | 7 +- .../components/connectors/jina/__init__.py | 7 + .../components/connectors/jina/reader.py | 141 ++++++++++++++++++ .../components/connectors/jina/reader_mode.py | 40 +++++ .../jina/tests/test_reader_connector.py | 141 ++++++++++++++++++ 7 files changed, 383 insertions(+), 1 deletion(-) create mode 100644 integrations/jina/examples/jina_reader_connector.py create mode 100644 integrations/jina/src/haystack_integrations/components/connectors/jina/__init__.py create mode 100644 integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py create mode 100644 integrations/jina/src/haystack_integrations/components/connectors/jina/reader_mode.py create mode 100644 integrations/jina/tests/test_reader_connector.py diff --git a/integrations/jina/examples/jina_reader_connector.py b/integrations/jina/examples/jina_reader_connector.py new file mode 100644 index 000000000..24b6f5db3 --- /dev/null +++ b/integrations/jina/examples/jina_reader_connector.py @@ -0,0 +1,47 @@ +# to make use of the JinaReaderConnector, we first need to install the Haystack integration +# pip install jina-haystack + +# then we must set the JINA_API_KEY environment variable +# export JINA_API_KEY= + + +from haystack_integrations.components.connectors.jina import JinaReaderConnector + +# we can use the JinaReaderConnector to process a URL and return the textual content of the page +reader = JinaReaderConnector(mode="read") +query = "https://example.com" +result = reader.run(query=query) + +print(result) +# {'documents': [Document(id=fa3e51e4ca91828086dca4f359b6e1ea2881e358f83b41b53c84616cb0b2f7cf, +# content: 'This domain is for use in illustrative examples in documents. You may use this domain in literature ...', +# meta: {'title': 'Example Domain', 'description': '', 'url': 'https://example.com/', 'usage': {'tokens': 42}})]} + + +# we can perform a web search by setting the mode to "search" +reader = JinaReaderConnector(mode="search") +query = "UEFA Champions League 2024" +result = reader.run(query=query) + +print(result) +# {'documents': Document(id=6a71abf9955594232037321a476d39a835c0cb7bc575d886ee0087c973c95940, +# content: '2024/25 UEFA Champions League: Matches, draw, final, key dates | UEFA Champions League | UEFA.com...', +# meta: {'title': '2024/25 UEFA Champions League: Matches, draw, final, key dates', +# 'description': 'What are the match dates? Where is the 2025 final? How will the competition work?', +# 'url': 'https://www.uefa.com/uefachampionsleague/news/...', +# 'usage': {'tokens': 5581}}), ...]} + + +# finally, we can perform fact-checking by setting the mode to "ground" (experimental) +reader = JinaReaderConnector(mode="ground") +query = "ChatGPT was launched in 2017" +result = reader.run(query=query) + +print(result) +# {'documents': [Document(id=f0c964dbc1ebb2d6584c8032b657150b9aa6e421f714cc1b9f8093a159127f0c, +# content: 'The statement that ChatGPT was launched in 2017 is incorrect. 
Multiple references confirm that ChatG...', +# meta: {'factuality': 0, 'result': False, 'references': [ +# {'url': 'https://en.wikipedia.org/wiki/ChatGPT', +# 'keyQuote': 'ChatGPT is a generative artificial intelligence (AI) chatbot developed by OpenAI and launched in 2022.', +# 'isSupportive': False}, ...], +# 'usage': {'tokens': 10188}})]} diff --git a/integrations/jina/pydoc/config.yml b/integrations/jina/pydoc/config.yml index 8c7a241f6..2d0ef4f87 100644 --- a/integrations/jina/pydoc/config.yml +++ b/integrations/jina/pydoc/config.yml @@ -6,6 +6,7 @@ loaders: "haystack_integrations.components.embedders.jina.document_embedder", "haystack_integrations.components.embedders.jina.text_embedder", "haystack_integrations.components.rankers.jina.ranker", + "haystack_integrations.components.connectors.jina.reader", ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index c89eeacb4..e3af086d0 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -132,18 +132,23 @@ ban-relative-imports = "parents" [tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] +# examples can contain "print" commands +"examples/**/*" = ["T201"] [tool.coverage.run] source = ["haystack_integrations"] branch = true parallel = false - [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + [[tool.mypy.overrides]] module = ["haystack.*", "haystack_integrations.*", "pytest.*"] ignore_missing_imports = true diff --git a/integrations/jina/src/haystack_integrations/components/connectors/jina/__init__.py b/integrations/jina/src/haystack_integrations/components/connectors/jina/__init__.py new file mode 100644 index 000000000..95368df21 --- /dev/null +++ b/integrations/jina/src/haystack_integrations/components/connectors/jina/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .reader import JinaReaderConnector +from .reader_mode import JinaReaderMode + +__all__ = ["JinaReaderConnector", "JinaReaderMode"] diff --git a/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py new file mode 100644 index 000000000..eb53329f7 --- /dev/null +++ b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Any, Dict, List, Optional, Union +from urllib.parse import quote + +import requests +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace + +from .reader_mode import JinaReaderMode + +READER_ENDPOINT_URL_BY_MODE = { + JinaReaderMode.READ: "https://r.jina.ai/", + JinaReaderMode.SEARCH: "https://s.jina.ai/", + JinaReaderMode.GROUND: "https://g.jina.ai/", +} + + +@component +class JinaReaderConnector: + """ + A component that interacts with Jina AI's reader service to process queries and return documents. + + This component supports different modes of operation: `read`, `search`, and `ground`. 
+ + Usage example: + ```python + from haystack_integrations.components.connectors.jina import JinaReaderConnector + + reader = JinaReaderConnector(mode="read") + query = "https://example.com" + result = reader.run(query=query) + document = result["documents"][0] + print(document.content) + + >>> "This domain is for use in illustrative examples..." + ``` + """ + + def __init__( + self, + mode: Union[JinaReaderMode, str], + api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008 + json_response: bool = True, + ): + """ + Initialize a JinaReader instance. + + :param mode: The operation mode for the reader (`read`, `search` or `ground`). + - `read`: process a URL and return the textual content of the page. + - `search`: search the web and return textual content of the most relevant pages. + - `ground`: call the grounding engine to perform fact checking. + For more information on the modes, see the [Jina Reader documentation](https://jina.ai/reader/). + :param api_key: The Jina API key. It can be explicitly provided or automatically read from the + environment variable JINA_API_KEY (recommended). + :param json_response: Controls the response format from the Jina Reader API. + If `True`, requests a JSON response, resulting in Documents with rich structured metadata. + If `False`, requests a raw response, resulting in one Document with minimal metadata. + """ + self.api_key = api_key + self.json_response = json_response + + if isinstance(mode, str): + mode = JinaReaderMode.from_str(mode) + self.mode = mode + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + api_key=self.api_key.to_dict(), + mode=str(self.mode), + json_response=self.json_response, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "JinaReaderConnector": + """ + Deserializes the component from a dictionary. + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _json_to_document(self, data: dict) -> Document: + """ + Convert a JSON response/record to a Document, depending on the reader mode. + """ + if self.mode == JinaReaderMode.GROUND: + content = data.pop("reason") + else: + content = data.pop("content") + document = Document(content=content, meta=data) + return document + + @component.output_types(document=List[Document]) + def run(self, query: str, headers: Optional[Dict[str, str]] = None): + """ + Process the query/URL using the Jina AI reader service. + + :param query: The query string or URL to process. + :param headers: Optional headers to include in the request for customization. Refer to the + [Jina Reader documentation](https://jina.ai/reader/) for more information. + + :returns: + A dictionary with the following keys: + - `documents`: A list of `Document` objects. 
+ """ + headers = headers or {} + headers["Authorization"] = f"Bearer {self.api_key.resolve_value()}" + + if self.json_response: + headers["Accept"] = "application/json" + + endpoint_url = READER_ENDPOINT_URL_BY_MODE[self.mode] + encoded_target = quote(query, safe="") + url = f"{endpoint_url}{encoded_target}" + + response = requests.get(url, headers=headers, timeout=60) + + # raw response: we just return a single Document with text + if not self.json_response: + meta = {"content_type": response.headers["Content-Type"], "query": query} + return {"documents": [Document(content=response.content, meta=meta)]} + + response_json = json.loads(response.content).get("data", {}) + if self.mode == JinaReaderMode.SEARCH: + documents = [self._json_to_document(record) for record in response_json] + return {"documents": documents} + + return {"documents": [self._json_to_document(response_json)]} diff --git a/integrations/jina/src/haystack_integrations/components/connectors/jina/reader_mode.py b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader_mode.py new file mode 100644 index 000000000..2ccf7250b --- /dev/null +++ b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader_mode.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from enum import Enum + + +class JinaReaderMode(Enum): + """ + Enum representing modes for the Jina Reader. + + Modes: + READ: Process a URL and return the textual content of the page. + SEARCH: Search the web and return the textual content of the most relevant pages. + GROUND: Call the grounding engine to perform fact checking. + + """ + + READ = "read" + SEARCH = "search" + GROUND = "ground" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "JinaReaderMode": + """ + Create the reader mode from a string. + + :param string: + String to convert. + :returns: + Reader mode. + """ + enum_map = {e.value: e for e in JinaReaderMode} + reader_mode = enum_map.get(string) + if reader_mode is None: + msg = f"Unknown reader mode '{string}'. 
Supported modes are: {list(enum_map.keys())}" + raise ValueError(msg) + return reader_mode diff --git a/integrations/jina/tests/test_reader_connector.py b/integrations/jina/tests/test_reader_connector.py new file mode 100644 index 000000000..449f73df8 --- /dev/null +++ b/integrations/jina/tests/test_reader_connector.py @@ -0,0 +1,141 @@ +import json +import os +from unittest.mock import patch + +import pytest +from haystack import Document +from haystack.utils import Secret + +from haystack_integrations.components.connectors.jina import JinaReaderConnector, JinaReaderMode + + +class TestJinaReaderConnector: + def test_init_with_custom_parameters(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-api-key") + reader = JinaReaderConnector(mode="read", api_key=Secret.from_env_var("TEST_KEY"), json_response=False) + + assert reader.mode == JinaReaderMode.READ + assert reader.api_key.resolve_value() == "test-api-key" + assert reader.json_response is False + + def test_init_with_invalid_mode(self): + with pytest.raises(ValueError): + JinaReaderConnector(mode="INVALID") + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-api-key") + reader = JinaReaderConnector(mode="search", api_key=Secret.from_env_var("TEST_KEY"), json_response=True) + + serialized = reader.to_dict() + + assert serialized["type"] == "haystack_integrations.components.connectors.jina.reader.JinaReaderConnector" + assert "init_parameters" in serialized + + init_params = serialized["init_parameters"] + assert init_params["mode"] == "search" + assert init_params["json_response"] is True + assert "api_key" in init_params + assert init_params["api_key"]["type"] == "env_var" + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "test-api-key") + component_dict = { + "type": "haystack_integrations.components.connectors.jina.reader.JinaReaderConnector", + "init_parameters": { + "api_key": {"type": "env_var", "env_vars": ["JINA_API_KEY"], "strict": True}, + "mode": "read", + "json_response": True, + }, + } + + reader = JinaReaderConnector.from_dict(component_dict) + + assert isinstance(reader, JinaReaderConnector) + assert reader.mode == JinaReaderMode.READ + assert reader.json_response is True + assert reader.api_key.resolve_value() == "test-api-key" + + def test_json_to_document_read_mode(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-api-key") + reader = JinaReaderConnector(mode="read") + + data = {"content": "Mocked content", "title": "Mocked Title", "url": "https://example.com"} + document = reader._json_to_document(data) + + assert isinstance(document, Document) + assert document.content == "Mocked content" + assert document.meta["title"] == "Mocked Title" + assert document.meta["url"] == "https://example.com" + + def test_json_to_document_ground_mode(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-api-key") + reader = JinaReaderConnector(mode="ground") + + data = { + "factuality": 0, + "result": False, + "reason": "The statement is contradicted by...", + "references": [{"url": "https://example.com", "keyQuote": "Mocked key quote", "isSupportive": False}], + } + + document = reader._json_to_document(data) + assert isinstance(document, Document) + assert document.content == "The statement is contradicted by..." 
+ assert document.meta["factuality"] == 0 + assert document.meta["result"] is False + assert document.meta["references"] == [ + {"url": "https://example.com", "keyQuote": "Mocked key quote", "isSupportive": False} + ] + + @patch("requests.get") + def test_run_with_mocked_response(self, mock_get, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "test-api-key") + mock_json_response = { + "data": {"content": "Mocked content", "title": "Mocked Title", "url": "https://example.com"} + } + mock_get.return_value.content = json.dumps(mock_json_response).encode("utf-8") + mock_get.return_value.headers = {"Content-Type": "application/json"} + + reader = JinaReaderConnector(mode="read") + result = reader.run(query="https://example.com") + + assert mock_get.call_count == 1 + assert mock_get.call_args[0][0] == "https://r.jina.ai/https%3A%2F%2Fexample.com" + assert mock_get.call_args[1]["headers"] == { + "Authorization": "Bearer test-api-key", + "Accept": "application/json", + } + + assert len(result) == 1 + document = result["documents"][0] + assert isinstance(document, Document) + assert document.content == "Mocked content" + assert document.meta["title"] == "Mocked Title" + assert document.meta["url"] == "https://example.com" + + @pytest.mark.skipif(not os.environ.get("JINA_API_KEY", None), reason="JINA_API_KEY env var not set") + @pytest.mark.integration + def test_run_reader_mode(self): + reader = JinaReaderConnector(mode="read") + result = reader.run(query="https://example.com") + + assert len(result) == 1 + document = result["documents"][0] + assert isinstance(document, Document) + assert "This domain is for use in illustrative examples" in document.content + assert document.meta["title"] == "Example Domain" + assert document.meta["url"] == "https://example.com/" + + @pytest.mark.skipif(not os.environ.get("JINA_API_KEY", None), reason="JINA_API_KEY env var not set") + @pytest.mark.integration + def test_run_search_mode(self): + reader = JinaReaderConnector(mode="search") + result = reader.run(query="When was Jina AI founded?") + + assert len(result) >= 1 + for doc in result["documents"]: + assert isinstance(doc, Document) + assert doc.content + assert "title" in doc.meta + assert "url" in doc.meta + assert "description" in doc.meta From 6db7399a55099ccc3a4f96169bcd5fc46711bdf9 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 21 Nov 2024 17:02:12 +0000 Subject: [PATCH 099/229] Update the changelog --- integrations/jina/CHANGELOG.md | 56 +++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 8 deletions(-) diff --git a/integrations/jina/CHANGELOG.md b/integrations/jina/CHANGELOG.md index 918a764f0..f65853d31 100644 --- a/integrations/jina/CHANGELOG.md +++ b/integrations/jina/CHANGELOG.md @@ -1,17 +1,49 @@ # Changelog +## [integrations/jina-v0.5.0] - 2024-11-21 + +### 🚀 Features + +- Add `JinaReaderConnector` (#1150) + +### 📚 Documentation + +- Update docstrings of JinaDocumentEmbedder and JinaTextEmbedder (#1092) + +### ⚙️ CI + +- Adopt uv as installer (#1142) + +### 🧹 Chores + +- Update ruff linting scripts and settings (#1105) + + ## [integrations/jina-v0.4.0] - 2024-09-18 ### 🧪 Testing - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) - Update Jina Embedder usage for V3 release (#1077) +### 🌀 Miscellaneous + +- Remove references to Python 3.7 (#601) +- Jina - add missing ranker to API reference (#610) +- Jina 
ranker: fix wrong URL in docstring (#628) +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Chore: Jina - ruff update, don't ruff tests (#982) + ## [integrations/jina-v0.3.0] - 2024-03-19 ### 🚀 Features @@ -22,13 +54,17 @@ - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Jina - remove dead code (#422) +- Jina - review docstrings (#504) +- Make tests show coverage (#566) + ## [integrations/jina-v0.2.0] - 2024-02-14 ### 🚀 Features @@ -39,7 +75,7 @@ This PR will also push the docs to Readme - Update paths and titles (#397) -### Jina +### 🌀 Miscellaneous - Update secrets management (#411) @@ -47,18 +83,22 @@ This PR will also push the docs to Readme ### 🐛 Bug Fixes -- Fix project urls (#96) - - +- Fix project URLs (#96) ### 🚜 Refactor - Use `hatch_vcs` to manage integrations versioning (#103) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - [**breaking**] Rename model_name to model in the Jina integration (#230) +### 🌀 Miscellaneous + +- Change metadata to meta (#152) +- Optimize API key reading (#162) +- Refact!:change import paths (#254) + ## [integrations/jina-v0.0.1] - 2023-12-11 ### 🚀 Features From a97966734054c79b96e9662ffe3dd5dec1736d03 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 22 Nov 2024 10:46:28 +0100 Subject: [PATCH 100/229] fix: Fix error in README file (#1207) * Remove incorrect colab link --- integrations/azure_ai_search/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/azure_ai_search/README.md b/integrations/azure_ai_search/README.md index 915a23b63..51cc7720c 100644 --- a/integrations/azure_ai_search/README.md +++ b/integrations/azure_ai_search/README.md @@ -19,7 +19,7 @@ pip install azure-ai-search-haystack ``` ## Examples -You can find a code example showing how to use the Document Store and the Retriever in the documentation or in [this Colab](https://colab.research.google.com/drive/1YpDetI8BRbObPDEVdfqUcwhEX9UUXP-m?usp=sharing). +Refer to the documentation for code examples on utilizing the Document Store and its associated Retrievers. For more usage scenarios, check out the [examples](https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search/example). 
## License From 8d491728725e02a21fd9d3eeee426a5831e4ee9b Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 22 Nov 2024 11:07:25 +0000 Subject: [PATCH 101/229] Update the changelog --- integrations/azure_ai_search/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/azure_ai_search/CHANGELOG.md b/integrations/azure_ai_search/CHANGELOG.md index e22559b12..6a8d26c9d 100644 --- a/integrations/azure_ai_search/CHANGELOG.md +++ b/integrations/azure_ai_search/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/azure_ai_search-v0.1.1] - 2024-11-22 + +### 🐛 Bug Fixes + +- Fix error in README file (#1207) + + ## [integrations/azure_ai_search-v0.1.0] - 2024-11-21 ### 🚀 Features From a6a78088a5fd53995fd743c918285244b8cdd0e4 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 22 Nov 2024 13:09:30 +0100 Subject: [PATCH 102/229] fix: adapt to Ollama client 0.4.0 (#1209) * adapt to Ollama client 0.4.0 * remove explicit support for python 3.8 * fix linting --- integrations/ollama/pyproject.toml | 5 ++-- .../components/embedders/ollama/__init__.py | 2 +- .../embedders/ollama/document_embedder.py | 4 +-- .../embedders/ollama/text_embedder.py | 2 +- .../components/generators/ollama/__init__.py | 2 +- .../generators/ollama/chat/chat_generator.py | 17 +++++++----- .../components/generators/ollama/generator.py | 16 ++++++------ .../ollama/tests/test_chat_generator.py | 26 +++++++++---------- 8 files changed, 38 insertions(+), 36 deletions(-) diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index 598d1d214..c9fc22f3d 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -19,7 +19,6 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -27,7 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "ollama"] +dependencies = ["haystack-ai", "ollama>=0.4.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme" @@ -63,7 +62,7 @@ cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +python = ["3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py index 46042a1c9..822b3d0aa 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py @@ -1,4 +1,4 @@ from .document_embedder import OllamaDocumentEmbedder from .text_embedder import OllamaTextEmbedder -__all__ = ["OllamaTextEmbedder", "OllamaDocumentEmbedder"] +__all__ = ["OllamaDocumentEmbedder", "OllamaTextEmbedder"] diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py index ac8f38f35..2fab6c72f 100644 --- 
a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py @@ -100,7 +100,7 @@ def _embed_batch( range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): batch = texts_to_embed[i] # Single batch only - result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs) + result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs).model_dump() all_embeddings.append(result["embedding"]) meta["model"] = self.model @@ -122,7 +122,7 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A - `documents`: Documents with embedding information attached - `meta`: The metadata collected during the embedding process """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "OllamaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a list of strings, please use the OllamaTextEmbedder." diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py index 7779c6d6e..b08b8bef3 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py @@ -62,7 +62,7 @@ def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None): - `embedding`: The computed embeddings - `meta`: The metadata collected during the embedding process """ - result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs) + result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs).model_dump() result["meta"] = {"model": self.model} return result diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py index 41a02d0ac..24e4d2edb 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py @@ -1,4 +1,4 @@ from .chat.chat_generator import OllamaChatGenerator from .generator import OllamaGenerator -__all__ = ["OllamaGenerator", "OllamaChatGenerator"] +__all__ = ["OllamaChatGenerator", "OllamaGenerator"] diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 558fd593e..b1be7a2db 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -4,7 +4,7 @@ from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.utils.callable_serialization import deserialize_callable, serialize_callable -from ollama import Client +from ollama import ChatResponse, Client @component @@ -111,12 +111,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator": def _message_to_dict(self, 
message: ChatMessage) -> Dict[str, str]: return {"role": message.role.value, "content": message.content} - def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage: + def _build_message_from_ollama_response(self, ollama_response: ChatResponse) -> ChatMessage: """ Converts the non-streaming response from the Ollama API to a ChatMessage. """ - message = ChatMessage.from_assistant(content=ollama_response["message"]["content"]) - message.meta.update({key: value for key, value in ollama_response.items() if key != "message"}) + response_dict = ollama_response.model_dump() + message = ChatMessage.from_assistant(content=response_dict["message"]["content"]) + message.meta.update({key: value for key, value in response_dict.items() if key != "message"}) return message def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: @@ -133,9 +134,11 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk: """ Converts the response from the Ollama API to a StreamingChunk. """ - content = chunk_response["message"]["content"] - meta = {key: value for key, value in chunk_response.items() if key != "message"} - meta["role"] = chunk_response["message"]["role"] + chunk_response_dict = chunk_response.model_dump() + + content = chunk_response_dict["message"]["content"] + meta = {key: value for key, value in chunk_response_dict.items() if key != "message"} + meta["role"] = chunk_response_dict["message"]["role"] chunk_message = StreamingChunk(content, meta) return chunk_message diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 058948e8a..dad671c94 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -4,7 +4,7 @@ from haystack.dataclasses import StreamingChunk from haystack.utils.callable_serialization import deserialize_callable, serialize_callable -from ollama import Client +from ollama import Client, GenerateResponse @component @@ -118,15 +118,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator": data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) - def _convert_to_response(self, ollama_response: Dict[str, Any]) -> Dict[str, List[Any]]: + def _convert_to_response(self, ollama_response: GenerateResponse) -> Dict[str, List[Any]]: """ Converts a response from the Ollama API to the required Haystack format. """ + reply = ollama_response.response + meta = {key: value for key, value in ollama_response.model_dump().items() if key != "response"} - replies = [ollama_response["response"]] - meta = {key: value for key, value in ollama_response.items() if key != "response"} - - return {"replies": replies, "meta": [meta]} + return {"replies": [reply], "meta": [meta]} def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: """ @@ -154,8 +153,9 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk: """ Converts the response from the Ollama API to a StreamingChunk. 
""" - content = chunk_response["response"] - meta = {key: value for key, value in chunk_response.items() if key != "response"} + chunk_response_dict = chunk_response.model_dump() + content = chunk_response_dict["response"] + meta = {key: value for key, value in chunk_response_dict.items() if key != "response"} chunk_message = StreamingChunk(content, meta) return chunk_message diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 5ac9289aa..b2b3fd927 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -4,7 +4,7 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole -from ollama._types import ResponseError +from ollama._types import ChatResponse, ResponseError from haystack_integrations.components.generators.ollama import OllamaChatGenerator @@ -86,18 +86,18 @@ def test_from_dict(self): def test_build_message_from_ollama_response(self): model = "some_model" - ollama_response = { - "model": model, - "created_at": "2023-12-12T14:13:43.416799Z", - "message": {"role": "assistant", "content": "Hello! How are you today?"}, - "done": True, - "total_duration": 5191566416, - "load_duration": 2154458, - "prompt_eval_count": 26, - "prompt_eval_duration": 383809000, - "eval_count": 298, - "eval_duration": 4799921000, - } + ollama_response = ChatResponse( + model=model, + created_at="2023-12-12T14:13:43.416799Z", + message={"role": "assistant", "content": "Hello! How are you today?"}, + done=True, + total_duration=5191566416, + load_duration=2154458, + prompt_eval_count=26, + prompt_eval_duration=383809000, + eval_count=298, + eval_duration=4799921000, + ) observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response) From cdd6555adcc1bcd21ab4dcbfc041d955c8d05cbd Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 22 Nov 2024 14:50:48 +0100 Subject: [PATCH 103/229] add connector to README (#1208) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index af83d045d..c7605f0e5 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | | [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | | [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / 
instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | -| [jina-haystack](integrations/jina/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | +| [jina-haystack](integrations/jina/) | Connector, Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | | [langfuse-haystack](integrations/langfuse/) | Tracer | [![PyPI - Version](https://img.shields.io/pypi/v/langfuse-haystack.svg?color=orange)](https://pypi.org/project/langfuse-haystack) | [![Test / langfuse](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml) | | [llama-cpp-haystack](integrations/llama_cpp/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/llama-cpp-haystack.svg?color=orange)](https://pypi.org/project/llama-cpp-haystack) | [![Test / llama-cpp](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml) | | [mistral-haystack](integrations/mistral/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/mistral-haystack.svg)](https://pypi.org/project/mistral-haystack) | [![Test / mistral](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml) | From 5c56a87b531d7769a774538dca2e828770b19731 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 22 Nov 2024 15:45:37 +0100 Subject: [PATCH 104/229] add ollama missing changelog (#1214) --- integrations/ollama/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 55c6aa7b7..29c8dd910 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/ollama-v2.0.0] - 2024-11-22 + +### 🐛 Bug Fixes + +- Adapt to Ollama client 0.4.0 (#1209) + ## [integrations/ollama-v1.1.0] - 2024-10-11 ### 🚀 Features From de32fa3602153fa61311843e50504b231472246d Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 22 Nov 2024 16:17:41 +0100 Subject: [PATCH 105/229] chore: improve README release section (#1211) * improve README release section * fix --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index c7605f0e5..43a2610ac 100644 --- a/README.md +++ b/README.md @@ -85,3 +85,8 @@ GitHub. The GitHub Actions workflow will take care of the rest. git push --tags origin ``` 3. 
Wait for the CI to do its magic + +> [!IMPORTANT] +> When releasing a new integration version, always tag a commit that includes the changes for that integration +> (usually the PR merge commit). If you tag a commit that doesn't include changes for the integration being released, +> the generated changelog will be incorrect. From f286bdf6bbc8e8eedc679edff0ef59b1807f2eb3 Mon Sep 17 00:00:00 2001 From: Qunfei Wu Date: Fri, 22 Nov 2024 19:24:38 +0100 Subject: [PATCH 106/229] feat: add `create_extension` parameter to control vector extension creation (#1213) * Create feature called create_extension and updated documentation. * small refinements * update docker image --------- Co-authored-by: anakin87 --- .github/workflows/pgvector.yml | 2 +- .../document_stores/pgvector/document_store.py | 10 +++++++++- integrations/pgvector/tests/test_document_store.py | 4 ++++ integrations/pgvector/tests/test_retrievers.py | 6 ++++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pgvector.yml b/.github/workflows/pgvector.yml index 0fe20e037..ab5c984ed 100644 --- a/.github/workflows/pgvector.yml +++ b/.github/workflows/pgvector.yml @@ -33,7 +33,7 @@ jobs: python-version: ["3.9", "3.10", "3.11"] services: pgvector: - image: ankane/pgvector:latest + image: pgvector/pgvector:pg17 env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 6682c2fee..87655a5ec 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -78,6 +78,7 @@ def __init__( self, *, connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), + create_extension: bool = True, schema_name: str = "public", table_name: str = "haystack_documents", language: str = "english", @@ -102,6 +103,10 @@ def __init__( e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"` See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) for more details. + :param create_extension: Whether to create the pgvector extension if it doesn't exist. + Set this to `True` (default) to automatically create the extension if it is missing. + Creating the extension may require superuser privileges. + If set to `False`, ensure the extension is already installed; otherwise, an error will be raised. :param schema_name: The name of the schema the table is created in. The schema must already exist. :param table_name: The name of the table to use to store Haystack documents. :param language: The language to be used to parse query and document content in keyword retrieval. @@ -138,6 +143,7 @@ def __init__( """ self.connection_string = connection_string + self.create_extension = create_extension self.table_name = table_name self.schema_name = schema_name self.embedding_dimension = embedding_dimension @@ -194,7 +200,8 @@ def _create_connection(self): conn_str = self.connection_string.resolve_value() or "" connection = connect(conn_str) connection.autocommit = True - connection.execute("CREATE EXTENSION IF NOT EXISTS vector") + if self.create_extension: + connection.execute("CREATE EXTENSION IF NOT EXISTS vector") register_vector(connection) # Note: this must be called before creating the cursors. 
self._connection = connection @@ -246,6 +253,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, connection_string=self.connection_string.to_dict(), + create_extension=self.create_extension, schema_name=self.schema_name, table_name=self.table_name, embedding_dimension=self.embedding_dimension, diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index c6f160f91..baa921137 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -66,6 +66,7 @@ def test_init(monkeypatch): monkeypatch.setenv("PG_CONN_STR", "some_connection_string") document_store = PgvectorDocumentStore( + create_extension=True, schema_name="my_schema", table_name="my_table", embedding_dimension=512, @@ -79,6 +80,7 @@ def test_init(monkeypatch): keyword_index_name="my_keyword_index", ) + assert document_store.create_extension assert document_store.schema_name == "my_schema" assert document_store.table_name == "my_table" assert document_store.embedding_dimension == 512 @@ -97,6 +99,7 @@ def test_to_dict(monkeypatch): monkeypatch.setenv("PG_CONN_STR", "some_connection_string") document_store = PgvectorDocumentStore( + create_extension=False, table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -113,6 +116,7 @@ def test_to_dict(monkeypatch): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": False, "table_name": "my_table", "schema_name": "public", "embedding_dimension": 512, diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py index 4125c3e3a..11be71ab1 100644 --- a/integrations/pgvector/tests/test_retrievers.py +++ b/integrations/pgvector/tests/test_retrievers.py @@ -50,6 +50,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": True, "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, @@ -82,6 +83,7 @@ def test_from_dict(self, monkeypatch): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": False, "table_name": "haystack_test_to_dict", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -106,6 +108,7 @@ def test_from_dict(self, monkeypatch): assert isinstance(document_store, PgvectorDocumentStore) assert isinstance(document_store.connection_string, EnvVarSecret) + assert not document_store.create_extension assert document_store.table_name == "haystack_test_to_dict" assert document_store.embedding_dimension == 768 assert document_store.vector_function == "cosine_similarity" @@ -176,6 +179,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": True, "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, @@ -207,6 +211,7 @@ def test_from_dict(self, monkeypatch): "type": 
"haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": False, "table_name": "haystack_test_to_dict", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -230,6 +235,7 @@ def test_from_dict(self, monkeypatch): assert isinstance(document_store, PgvectorDocumentStore) assert isinstance(document_store.connection_string, EnvVarSecret) + assert not document_store.create_extension assert document_store.table_name == "haystack_test_to_dict" assert document_store.embedding_dimension == 768 assert document_store.vector_function == "cosine_similarity" From 6909e46735853e4cf329381cc1bdc93a216886e9 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 22 Nov 2024 18:26:34 +0000 Subject: [PATCH 107/229] Update the changelog --- integrations/pgvector/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/pgvector/CHANGELOG.md b/integrations/pgvector/CHANGELOG.md index 7c8be2340..f3821f1d3 100644 --- a/integrations/pgvector/CHANGELOG.md +++ b/integrations/pgvector/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/pgvector-v1.2.0] - 2024-11-22 + +### 🚀 Features + +- Add `create_extension` parameter to control vector extension creation (#1213) + + ## [integrations/pgvector-v1.1.0] - 2024-11-21 ### 🚀 Features From 7b9c8a6cabfe64592ff8acd550cea7f585c50a24 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 25 Nov 2024 09:32:38 +0100 Subject: [PATCH 108/229] chore: fix linting/isort (#1215) * fix linting * vertex fix --- .../components/embedders/amazon_bedrock/__init__.py | 2 +- .../embedders/amazon_bedrock/document_embedder.py | 2 +- .../components/generators/amazon_bedrock/__init__.py | 2 +- .../components/generators/anthropic/__init__.py | 2 +- .../document_stores/azure_ai_search/__init__.py | 2 +- .../components/retrievers/chroma/__init__.py | 2 +- .../components/embedders/cohere/document_embedder.py | 2 +- .../components/generators/cohere/__init__.py | 2 +- .../components/embedders/fastembed/__init__.py | 2 +- .../embedders/fastembed/fastembed_document_embedder.py | 2 +- .../fastembed/fastembed_sparse_document_embedder.py | 2 +- .../components/rankers/fastembed/ranker.py | 2 +- .../components/generators/google_ai/__init__.py | 2 +- .../components/generators/google_vertex/__init__.py | 2 +- integrations/google_vertex/tests/chat/test_gemini.py | 6 +++--- .../instructor_embedders/instructor_document_embedder.py | 2 +- .../components/embedders/jina/document_embedder.py | 2 +- .../components/generators/llama_cpp/__init__.py | 2 +- .../components/embedders/nvidia/__init__.py | 2 +- .../components/embedders/nvidia/document_embedder.py | 2 +- .../src/haystack_integrations/utils/nvidia/__init__.py | 2 +- .../components/embedders/optimum/__init__.py | 4 ++-- .../embedders/optimum/optimum_document_embedder.py | 2 +- .../components/retrievers/qdrant/__init__.py | 2 +- .../document_stores/weaviate/__init__.py | 2 +- 25 files changed, 28 insertions(+), 28 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py index b2efefdc8..2ebd35979 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py +++ 
b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py @@ -4,4 +4,4 @@ from .document_embedder import AmazonBedrockDocumentEmbedder from .text_embedder import AmazonBedrockTextEmbedder -__all__ = ["AmazonBedrockTextEmbedder", "AmazonBedrockDocumentEmbedder"] +__all__ = ["AmazonBedrockDocumentEmbedder", "AmazonBedrockTextEmbedder"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py index 1b8fde124..f2906c00d 100755 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py @@ -236,7 +236,7 @@ def run(self, documents: List[Document]): - `documents`: The `Document`s with the `embedding` field populated. :raises AmazonBedrockInferenceError: If the inference fails. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "AmazonBedrockDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the AmazonBedrockTextEmbedder." diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 2d33beb42..ab3f0dfd5 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -4,4 +4,4 @@ from .chat.chat_generator import AmazonBedrockChatGenerator from .generator import AmazonBedrockGenerator -__all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator"] +__all__ = ["AmazonBedrockChatGenerator", "AmazonBedrockGenerator"] diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py index 0bd29898e..12c588dc4 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py @@ -5,4 +5,4 @@ from .chat.vertex_chat_generator import AnthropicVertexChatGenerator from .generator import AnthropicGenerator -__all__ = ["AnthropicGenerator", "AnthropicChatGenerator", "AnthropicVertexChatGenerator"] +__all__ = ["AnthropicChatGenerator", "AnthropicGenerator", "AnthropicVertexChatGenerator"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py index ca0ea7554..dcee0e622 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -4,4 +4,4 @@ from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore from .filters import _normalize_filters -__all__ = ["AzureAISearchDocumentStore", 
"DEFAULT_VECTOR_SEARCH", "_normalize_filters"] +__all__ = ["DEFAULT_VECTOR_SEARCH", "AzureAISearchDocumentStore", "_normalize_filters"] diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py index 53120c97c..e240ba136 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py @@ -1,3 +1,3 @@ from .retriever import ChromaEmbeddingRetriever, ChromaQueryTextRetriever -__all__ = ["ChromaQueryTextRetriever", "ChromaEmbeddingRetriever"] +__all__ = ["ChromaEmbeddingRetriever", "ChromaQueryTextRetriever"] diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index 3201168a8..d311662fe 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -146,7 +146,7 @@ def run(self, documents: List[Document]): - `meta`: metadata about the embedding process. :raises TypeError: if the input is not a list of `Documents`. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "CohereDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the CohereTextEmbedder." diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py index 93c0947e4..7d50682e8 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py @@ -4,4 +4,4 @@ from .chat.chat_generator import CohereChatGenerator from .generator import CohereGenerator -__all__ = ["CohereGenerator", "CohereChatGenerator"] +__all__ = ["CohereChatGenerator", "CohereGenerator"] diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py index e943a8ca1..d73c29766 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py @@ -8,7 +8,7 @@ __all__ = [ "FastembedDocumentEmbedder", - "FastembedTextEmbedder", "FastembedSparseDocumentEmbedder", "FastembedSparseTextEmbedder", + "FastembedTextEmbedder", ] diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index 8b63582c5..b064173fe 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -158,7 +158,7 @@ def run(self, documents: List[Document]): :returns: A dictionary with the 
following keys: - `documents`: List of Documents with each Document's `embedding` field set to the computed embeddings. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "FastembedDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the FastembedTextEmbedder." diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index a30d43cf4..fb3df9162 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -150,7 +150,7 @@ def run(self, documents: List[Document]): - `documents`: List of Documents with each Document's `sparse_embedding` field set to the computed embeddings. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "FastembedSparseDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the FastembedTextEmbedder." diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py index 8f077a30c..370344df5 100644 --- a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py +++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py @@ -157,7 +157,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None :raises ValueError: If `top_k` is not > 0. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = "FastembedRanker expects a list of Documents as input. 
" raise TypeError(msg) if query == "": diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py index 2b77c813f..c62129f9d 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py @@ -4,4 +4,4 @@ from .chat.gemini import GoogleAIGeminiChatGenerator from .gemini import GoogleAIGeminiGenerator -__all__ = ["GoogleAIGeminiGenerator", "GoogleAIGeminiChatGenerator"] +__all__ = ["GoogleAIGeminiChatGenerator", "GoogleAIGeminiGenerator"] diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py index 07c2a5260..e5f556637 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py @@ -11,8 +11,8 @@ __all__ = [ "VertexAICodeGenerator", - "VertexAIGeminiGenerator", "VertexAIGeminiChatGenerator", + "VertexAIGeminiGenerator", "VertexAIImageCaptioner", "VertexAIImageGenerator", "VertexAIImageQA", diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 73c99fe2f..614b83909 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -161,13 +161,13 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": { - "type_": "OBJECT", + "type": "OBJECT", "properties": { "location": { - "type_": "STRING", + "type": "STRING", "description": "The city and state, e.g. San Francisco, CA", }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + "unit": {"type": "STRING", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], "property_ordering": ["location", "unit"], diff --git a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py index 734798f46..c05c37733 100644 --- a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py +++ b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py @@ -158,7 +158,7 @@ def run(self, documents: List[Document]): param documents: A list of Documents to embed. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "InstructorDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the InstructorTextEmbedder." 
diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py index 715092b8a..103132faf 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py @@ -200,7 +200,7 @@ def run(self, documents: List[Document]): - `meta`: A dictionary with metadata including the model name and usage statistics. :raises TypeError: If the input is not a list of Documents. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "JinaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the JinaTextEmbedder." diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py index 10b20d363..a85dbfd88 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py @@ -5,4 +5,4 @@ from .chat.chat_generator import LlamaCppChatGenerator from .generator import LlamaCppGenerator -__all__ = ["LlamaCppGenerator", "LlamaCppChatGenerator"] +__all__ = ["LlamaCppChatGenerator", "LlamaCppGenerator"] diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py index 827ad7dc6..c6ecea7b1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py @@ -6,4 +6,4 @@ from .text_embedder import NvidiaTextEmbedder from .truncate import EmbeddingTruncateMode -__all__ = ["NvidiaDocumentEmbedder", "NvidiaTextEmbedder", "EmbeddingTruncateMode"] +__all__ = ["EmbeddingTruncateMode", "NvidiaDocumentEmbedder", "NvidiaTextEmbedder"] diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index 6519efbab..b417fa737 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -242,7 +242,7 @@ def run(self, documents: List[Document]): if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - elif not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + elif not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "NvidiaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the NvidiaTextEmbedder." 
diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py index f08cda6cd..0b69c8d24 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py @@ -5,4 +5,4 @@ from .nim_backend import Model, NimBackend from .utils import is_hosted, url_validation -__all__ = ["NimBackend", "Model", "is_hosted", "url_validation"] +__all__ = ["Model", "NimBackend", "is_hosted", "url_validation"] diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py index 02e56b34c..ec0ecdef1 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py @@ -10,10 +10,10 @@ __all__ = [ "OptimumDocumentEmbedder", - "OptimumEmbedderOptimizationMode", "OptimumEmbedderOptimizationConfig", + "OptimumEmbedderOptimizationMode", "OptimumEmbedderPooling", - "OptimumEmbedderQuantizationMode", "OptimumEmbedderQuantizationConfig", + "OptimumEmbedderQuantizationMode", "OptimumTextEmbedder", ] diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py index 27f533430..2016f3ffe 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py @@ -208,7 +208,7 @@ def run(self, documents: List[Document]): if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "OptimumDocumentEmbedder expects a list of Documents as input." " In case you want to embed a string, please use the OptimumTextEmbedder." 
diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py index ed6422bfe..bbb7251d0 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py @@ -4,4 +4,4 @@ from .retriever import QdrantEmbeddingRetriever, QdrantHybridRetriever, QdrantSparseEmbeddingRetriever -__all__ = ("QdrantEmbeddingRetriever", "QdrantSparseEmbeddingRetriever", "QdrantHybridRetriever") +__all__ = ("QdrantEmbeddingRetriever", "QdrantHybridRetriever", "QdrantSparseEmbeddingRetriever") diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py index 87c7b6b01..db084502b 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py @@ -5,10 +5,10 @@ from .document_store import WeaviateDocumentStore __all__ = [ - "WeaviateDocumentStore", "AuthApiKey", "AuthBearerToken", "AuthClientCredentials", "AuthClientPassword", "AuthCredentials", + "WeaviateDocumentStore", ] From c74bd539ea8e33adb3bd303a8df5de10d05917f5 Mon Sep 17 00:00:00 2001 From: David Basoco Date: Mon, 25 Nov 2024 13:05:08 +0100 Subject: [PATCH 109/229] Fix embedding retrieval top-k limit (#1210) Co-authored-by: David S. Batista --- .../document_stores/astra/astra_client.py | 3 +- .../astra/tests/test_embedding_retrieval.py | 48 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 integrations/astra/tests/test_embedding_retrieval.py diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index 6f2289786..1a3481e0c 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -202,7 +202,7 @@ def _format_query_response(responses, include_metadata, include_values): return QueryResponse(final_res) def _query(self, vector, top_k, filters=None): - query = {"sort": {"$vector": vector}, "options": {"limit": top_k, "includeSimilarity": True}} + query = {"sort": {"$vector": vector}, "limit": top_k, "includeSimilarity": True} if filters is not None: query["filter"] = filters @@ -222,6 +222,7 @@ def find_documents(self, find_query): filter=find_query.get("filter"), sort=find_query.get("sort"), limit=find_query.get("limit"), + include_similarity=find_query.get("includeSimilarity"), projection={"*": 1}, ) diff --git a/integrations/astra/tests/test_embedding_retrieval.py b/integrations/astra/tests/test_embedding_retrieval.py new file mode 100644 index 000000000..bf23fe9f5 --- /dev/null +++ b/integrations/astra/tests/test_embedding_retrieval.py @@ -0,0 +1,48 @@ +import os + +import pytest +from haystack import Document +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.astra import AstraDocumentStore + + +@pytest.mark.integration +@pytest.mark.skipif( + os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" +) +@pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", 
reason="ASTRA_DB_API_ENDPOINT env var not set") +class TestEmbeddingRetrieval: + + @pytest.fixture + def document_store(self) -> AstraDocumentStore: + return AstraDocumentStore( + collection_name="haystack_integration", + duplicates_policy=DuplicatePolicy.OVERWRITE, + embedding_dimension=768, + ) + + @pytest.fixture(autouse=True) + def run_before_and_after_tests(self, document_store: AstraDocumentStore): + """ + Cleaning up document store + """ + document_store.delete_documents(delete_all=True) + assert document_store.count_documents() == 0 + + def test_search_with_top_k(self, document_store): + query_embedding = [0.1] * 768 + common_embedding = [0.8] * 768 + + documents = [Document(content=f"This is document number {i}", embedding=common_embedding) for i in range(0, 3)] + + document_store.write_documents(documents) + + top_k = 2 + + result = document_store.search(query_embedding, top_k) + + assert top_k == len(result) + + for document in result: + assert document.score is not None From ad6068894112548a9ed65f9048ce3530e2fdab2f Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 25 Nov 2024 12:09:59 +0000 Subject: [PATCH 110/229] Update the changelog --- integrations/astra/CHANGELOG.md | 82 ++++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 12 deletions(-) diff --git a/integrations/astra/CHANGELOG.md b/integrations/astra/CHANGELOG.md index fff6cb65f..6ad660a0e 100644 --- a/integrations/astra/CHANGELOG.md +++ b/integrations/astra/CHANGELOG.md @@ -1,16 +1,29 @@ # Changelog +## [integrations/astra-v0.9.4] - 2024-11-25 + +### 🌀 Miscellaneous + +- Fix: Astra - fix embedding retrieval top-k limit (#1210) + ## [integrations/astra-v0.10.0] - 2024-10-22 ### 🚀 Features - Update astradb integration for latest client library (#1145) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI -- Update ruff linting scripts and settings (#1105) - Adopt uv as installer (#1142) +### 🧹 Chores + +- Update ruff linting scripts and settings (#1105) + +### 🌀 Miscellaneous + +- Fix: #1047 Remove count_documents from delete_documents (#1049) + ## [integrations/astra-v0.9.3] - 2024-09-12 ### 🐛 Bug Fixes @@ -22,8 +35,13 @@ - Do not retry tests in `hatch run test` command (#954) + ## [integrations/astra-v0.9.2] - 2024-07-22 +### 🌀 Miscellaneous + +- Normalize logical filter conditions (#874) + ## [integrations/astra-v0.9.1] - 2024-07-15 ### 🚀 Features @@ -37,27 +55,48 @@ - Fix typing checks - `Astra` - Fallback to default filter policy when deserializing retrievers without the init parameter (#896) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) +### 🌀 Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Fix: Incorrect astra not equal operator (#868) +- Chore: Minor retriever pydoc fix (#884) + ## [integrations/astra-v0.7.0] - 2024-05-15 ### 🐛 Bug Fixes - Make unit tests pass (#720) +### 🌀 Miscellaneous + +- Chore: change the pydoc renderer class (#718) +- [Astra DB] Explicit projection when reading from Astra DB (#733) + ## [integrations/astra-v0.6.0] - 2024-04-24 ### 🐛 Bug Fixes - Pass namespace in the docstore init (#683) +### 🌀 Miscellaneous + +- Chore: add license classifiers (#680) +- Bug fix for document_store.py (#618) + ## [integrations/astra-v0.5.1] - 2024-04-09 ### 🐛 Bug Fixes -- Fix haystack-ai pin (#649) +- Fix `haystack-ai` pins (#649) + +### 🌀 Miscellaneous + +- Remove references to Python 3.7 (#601) +- Make Document Stores initially skip `SparseEmbedding` (#606) ## [integrations/astra-v0.5.0] - 2024-03-18 @@ 
-67,9 +106,15 @@ - Small consistency improvements (#536) - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Fix example code for Astra DB pipeline (#481) +- Make tests show coverage (#566) +- Astra DB: Add integration usage tracking (#568) + ## [integrations/astra-v0.4.2] - 2024-02-21 -### FIX +### 🌀 Miscellaneous - Proper name for the sort param (#454) @@ -78,9 +123,7 @@ ### 🐛 Bug Fixes - Fix order of API docs (#447) - -This PR will also push the docs to Readme -- Fix integration tests (#450) +- Astra: fix integration tests (#450) ## [integrations/astra-v0.4.0] - 2024-02-20 @@ -88,20 +131,35 @@ This PR will also push the docs to Readme - Update category slug (#442) +### 🌀 Miscellaneous + +- Update the Astra DB Integration to fit latest conventions (#428) + ## [integrations/astra-v0.3.0] - 2024-02-15 -## [integrations/astra-v0.2.0] - 2024-02-13 +### 🌀 Miscellaneous -### Astra +- Model_name_or_path > model (#418) +- [Astra] Change authentication parameters (#423) -- Generate api docs (#327) +## [integrations/astra-v0.2.0] - 2024-02-13 -### Refact +### 🌀 Miscellaneous - [**breaking**] Change import paths (#277) +- Generate api docs (#327) +- Astra: rename retriever (#399) ## [integrations/astra-v0.1.1] - 2024-01-18 +### 🌀 Miscellaneous + +- Update the import paths for beta5 (#235) + ## [integrations/astra-v0.1.0] - 2024-01-11 +### 🌀 Miscellaneous + +- Adding AstraDB as a DocumentStore (#144) + From e6b0ede64c2085844eab93bd37eb82038613c119 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 25 Nov 2024 16:38:03 +0100 Subject: [PATCH 111/229] Update README.md (#1216) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 43a2610ac..0f8b2f0ee 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [llama-cpp-haystack](integrations/llama_cpp/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/llama-cpp-haystack.svg?color=orange)](https://pypi.org/project/llama-cpp-haystack) | [![Test / llama-cpp](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml) | | [mistral-haystack](integrations/mistral/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/mistral-haystack.svg)](https://pypi.org/project/mistral-haystack) | [![Test / mistral](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml) | | [mongodb-atlas-haystack](integrations/mongodb_atlas/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/mongodb-atlas-haystack.svg?color=orange)](https://pypi.org/project/mongodb-atlas-haystack) | [![Test / mongodb-atlas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml) | -| [nvidia-haystack](integrations/nvidia/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/nvidia-haystack.svg?color=orange)](https://pypi.org/project/nvidia-haystack) | [![Test / nvidia](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml) | +| [nvidia-haystack](integrations/nvidia/) | 
Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/nvidia-haystack.svg?color=orange)](https://pypi.org/project/nvidia-haystack) | [![Test / nvidia](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml) | | [ollama-haystack](integrations/ollama/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/ollama-haystack) | [![Test / ollama](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml) | | [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | | [optimum-haystack](integrations/optimum/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/optimum-haystack.svg)](https://pypi.org/project/optimum-haystack) | [![Test / optimum](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml) | From 1119b7bc63c2698662dcc34158858054e8c51d68 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 25 Nov 2024 17:33:32 +0100 Subject: [PATCH 112/229] update labeler with Azure AI Search (#1218) --- .github/labeler.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/labeler.yml b/.github/labeler.yml index 85f15788f..4e2899e4b 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -14,6 +14,11 @@ integration:astra: - any-glob-to-any-file: "integrations/astra/**/*" - any-glob-to-any-file: ".github/workflows/astra.yml" +integration:azure-ai-search: + - changed-files: + - any-glob-to-any-file: "integrations/azure_ai_search/**/*" + - any-glob-to-any-file: ".github/workflows/azure_ai_search.yml" + integration:chroma: - changed-files: - any-glob-to-any-file: "integrations/chroma/**/*" From b770a2910bbd9d7d388519478a3237f1930b64c7 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 26 Nov 2024 16:47:18 +0100 Subject: [PATCH 113/229] jina reader: rename the output edge (#1217) --- .../haystack_integrations/components/connectors/jina/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py index eb53329f7..618cacb4e 100644 --- a/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py +++ b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py @@ -103,7 +103,7 @@ def _json_to_document(self, data: dict) -> Document: document = Document(content=content, meta=data) return document - @component.output_types(document=List[Document]) + @component.output_types(documents=List[Document]) def run(self, query: str, headers: Optional[Dict[str, str]] = None): """ Process the query/URL using the Jina AI reader service. 
From 793598fa6966c11ac6807f67dbaf33222d4e3e21 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 26 Nov 2024 15:48:50 +0000 Subject: [PATCH 114/229] Update the changelog --- integrations/jina/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/jina/CHANGELOG.md b/integrations/jina/CHANGELOG.md index f65853d31..01de2abc1 100644 --- a/integrations/jina/CHANGELOG.md +++ b/integrations/jina/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/jina-v0.5.1] - 2024-11-26 + +### 🧹 Chores + +- Fix linting/isort (#1215) + +### 🌀 Miscellaneous + +- Fix: `JinaReaderConnector` - fix the name of the output edge (#1217) + ## [integrations/jina-v0.5.0] - 2024-11-21 ### 🚀 Features From 15dacfacb87127bf1ee9c645ad52048995007715 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 27 Nov 2024 10:33:12 +0100 Subject: [PATCH 115/229] Fix tracing_context_var lint errors (#1220) --- .../tracing/langfuse/tracer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index c1f8d4d93..d6f2535c7 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -42,7 +42,7 @@ # Context var used to keep track of tracing related info. # This mainly useful for parents spans. -tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context", default={}) +tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context") class LangfuseSpan(Span): @@ -147,15 +147,16 @@ def trace( operation_name=operation_name, ) # Create a new trace if no parent span is provided + context = tracing_context_var.get({}) span = LangfuseSpan( self._tracer.trace( name=self._name, public=self._public, - id=tracing_context_var.get().get("trace_id"), - user_id=tracing_context_var.get().get("user_id"), - session_id=tracing_context_var.get().get("session_id"), - tags=tracing_context_var.get().get("tags"), - version=tracing_context_var.get().get("version"), + id=context.get("trace_id"), + user_id=context.get("user_id"), + session_id=context.get("session_id"), + tags=context.get("tags"), + version=context.get("version"), ) ) elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS: From 5de49be83b7edec29a8278154955405603674018 Mon Sep 17 00:00:00 2001 From: Daria Fokina Date: Wed, 27 Nov 2024 15:10:11 +0100 Subject: [PATCH 116/229] add anth-vertex-chat-gn docs (#1221) --- integrations/anthropic/pydoc/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/anthropic/pydoc/config.yml b/integrations/anthropic/pydoc/config.yml index 9c1e39daf..bd3811571 100644 --- a/integrations/anthropic/pydoc/config.yml +++ b/integrations/anthropic/pydoc/config.yml @@ -4,6 +4,7 @@ loaders: modules: [ "haystack_integrations.components.generators.anthropic.generator", "haystack_integrations.components.generators.anthropic.chat.chat_generator", + "haystack_integrations.components.generators.anthropic.chat.vertex_chat_generator", ] ignore_when_discovered: ["__init__"] processors: From 94a29cbc111b30992f7964b42efa019aa5149675 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 28 Nov 2024 10:12:13 +0100 Subject: [PATCH 117/229] use class methods to create ChatMessage (#1222) --- .../tests/test_chat_generator.py | 4 +-- .../components/generators/cohere/generator.py | 4 +-- 
.../tests/test_cohere_chat_generator.py | 12 ++++---- .../generators/google_ai/chat/gemini.py | 28 ++++++------------- .../generators/google_vertex/chat/gemini.py | 28 ++++++------------- .../ollama/tests/test_chat_generator.py | 27 ++++++------------ 6 files changed, 35 insertions(+), 68 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 185a34c8a..22594af2c 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -226,10 +226,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.model_adapter.get_responses = MagicMock( return_value=[ - ChatMessage( + ChatMessage.from_assistant( content="Some text", - role=ChatRole.ASSISTANT, - name=None, meta={ "model": "claude-3-sonnet-20240229", "index": 0, diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 0eb65b368..e4eaf8670 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional from haystack import component -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses import ChatMessage from haystack.utils import Secret from .chat.chat_generator import CohereChatGenerator @@ -64,7 +64,7 @@ def run(self, prompt: str): - `replies`: A list of replies generated by the model. - `meta`: Information about the request. """ - chat_message = ChatMessage(content=prompt, role=ChatRole.USER, name="", meta={}) + chat_message = ChatMessage.from_user(prompt) # Note we have to call super() like this because of the way components are dynamically built with the decorator results = super(CohereGenerator, self).run([chat_message]) # noqa return {"replies": [results["replies"][0].content], "meta": [results["replies"][0].meta]} diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 175a6d14b..b7cc0534a 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -27,7 +27,7 @@ def streaming_chunk(text: str): @pytest.fixture def chat_messages(): - return [ChatMessage(content="What's the capital of France", role=ChatRole.ASSISTANT, name=None)] + return [ChatMessage.from_assistant(content="What's the capital of France")] class TestCohereChatGenerator: @@ -164,7 +164,7 @@ def test_message_to_dict(self, chat_messages): ) @pytest.mark.integration def test_live_run(self): - chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] + chat_messages = [ChatMessage.from_user(content="What's the capital of France")] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages) assert len(results["replies"]) == 1 @@ -201,9 +201,7 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() component = CohereChatGenerator(streaming_callback=callback) - results = component.run( - [ChatMessage(content="What's the capital of France? 
answer in a word", role=ChatRole.USER, name=None)] - ) + results = component.run([ChatMessage.from_user(content="What's the capital of France? answer in a word")]) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] @@ -224,7 +222,7 @@ def __call__(self, chunk: StreamingChunk) -> None: ) @pytest.mark.integration def test_live_run_with_connector(self): - chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] + chat_messages = [ChatMessage.from_user(content="What's the capital of France")] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 @@ -249,7 +247,7 @@ def __call__(self, chunk: StreamingChunk) -> None: self.responses += chunk.content if chunk.content else "" callback = Callback() - chat_messages = [ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)] + chat_messages = [ChatMessage.from_user(content="What's the capital of France? answer in a word")] component = CohereChatGenerator(streaming_callback=callback) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index dbcab619d..ef7d583be 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -334,19 +334,14 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess for part in candidate.content.parts: if part.text != "": - replies.append( - ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata) - ) + replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata)) elif part.function_call: candidate_metadata["function_call"] = part.function_call - replies.append( - ChatMessage( - content=dict(part.function_call.args.items()), - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=candidate_metadata, - ) + new_message = ChatMessage.from_assistant( + content=dict(part.function_call.args.items()), meta=candidate_metadata ) + new_message.name = part.function_call.name + replies.append(new_message) return replies def _get_stream_response( @@ -368,18 +363,13 @@ def _get_stream_response( for part in candidate["content"]["parts"]: if "text" in part and part["text"] != "": content = part["text"] - replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None)) + replies.append(ChatMessage.from_assistant(content=content, meta=metadata)) elif "function_call" in part and len(part["function_call"]) > 0: metadata["function_call"] = part["function_call"] content = part["function_call"]["args"] - replies.append( - ChatMessage( - content=content, - role=ChatRole.ASSISTANT, - name=part["function_call"]["name"], - meta=metadata, - ) - ) + new_message = ChatMessage.from_assistant(content=content, meta=metadata) + new_message.name = part["function_call"]["name"] + replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies diff --git 
a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index c52f76dc6..c94367b41 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -279,19 +279,14 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: # Remove content from metadata metadata.pop("content", None) if part._raw_part.text != "": - replies.append( - ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata) - ) + replies.append(ChatMessage.from_assistant(content=part._raw_part.text, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call - replies.append( - ChatMessage( - content=dict(part.function_call.args.items()), - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=metadata, - ) + new_message = ChatMessage.from_assistant( + content=dict(part.function_call.args.items()), meta=metadata ) + new_message.name = part.function_call.name + replies.append(new_message) return replies def _get_stream_response( @@ -313,18 +308,13 @@ def _get_stream_response( for part in candidate.content.parts: if part._raw_part.text: content = chunk.text - replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata)) + replies.append(ChatMessage.from_assistant(content, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call content = dict(part.function_call.args.items()) - replies.append( - ChatMessage( - content=content, - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=metadata, - ) - ) + new_message = ChatMessage.from_assistant(content, meta=metadata) + new_message.name = part.function_call.name + replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index b2b3fd927..0308f42ec 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -3,7 +3,7 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses import ChatMessage from ollama._types import ChatResponse, ResponseError from haystack_integrations.components.generators.ollama import OllamaChatGenerator @@ -128,16 +128,12 @@ def test_run_with_chat_history(self): chat_generator = OllamaChatGenerator() chat_history = [ - {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, - {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, - {"role": "user", "content": "And what is the second largest?"}, + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), ] - chat_messages = [ - ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) - for message in chat_history - ] - response = chat_generator.run(chat_messages) + response = chat_generator.run(chat_history) 
assert isinstance(response, dict) assert isinstance(response["replies"], list) @@ -159,17 +155,12 @@ def test_run_with_streaming(self): chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback) chat_history = [ - {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, - {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, - {"role": "user", "content": "And what is the second largest?"}, - ] - - chat_messages = [ - ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) - for message in chat_history + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), ] - response = chat_generator.run(chat_messages) + response = chat_generator.run(chat_history) streaming_callback.assert_called() From 57516056f5143ce25fdc70f1390f2bc4f10952e1 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 28 Nov 2024 15:44:25 +0100 Subject: [PATCH 118/229] chore: Chroma - pin `tokenizers` (#1223) * try adding tokenizers to dependencies * pin tokenizers * fix * nicer format --- integrations/chroma/pyproject.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index cfe7a606e..c91cc6cb0 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -22,7 +22,12 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "chromadb>=0.5.17", "typing_extensions>=4.8.0"] +dependencies = [ + "haystack-ai", + "chromadb>=0.5.17", + "typing_extensions>=4.8.0", + "tokenizers>=0.13.2,<=0.20.3" # TODO: remove when Chroma pins tokenizers internally +] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma#readme" From eb2cfb1a6a3f272023be776be9eb0cad83ba2420 Mon Sep 17 00:00:00 2001 From: latifboubyan Date: Thu, 28 Nov 2024 19:37:48 +0300 Subject: [PATCH 119/229] feat: `OllamaDocumentEmbedder` - allow batching embeddings (#1224) * use batch embeddings * extend embeddings with batch result * add batch_size parameter to OllamaDocumentEmbedder * use correct embed parameter * add unit test for ollama batch embed * refinements --------- Co-authored-by: David S. 
Batista Co-authored-by: anakin87 --- .../embedders/ollama/document_embedder.py | 36 ++++++++++++------- .../ollama/tests/test_document_embedder.py | 18 ++++++---- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py index 2fab6c72f..8d2f5f505 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py @@ -36,6 +36,7 @@ def __init__( progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + batch_size: int = 32, ): """ :param model: @@ -48,12 +49,24 @@ def __init__( [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). :param timeout: The number of seconds before throwing a timeout error from the Ollama API. + :param prefix: + A string to add at the beginning of each text. + :param suffix: + A string to add at the end of each text. + :param progress_bar: + If `True`, shows a progress bar when running. + :param meta_fields_to_embed: + List of metadata fields to embed along with the document text. + :param embedding_separator: + Separator used to concatenate the metadata fields to the document text. + :param batch_size: + Number of documents to process at once. """ self.timeout = timeout self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model - self.batch_size = 1 # API only supports a single call at the moment + self.batch_size = batch_size self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed self.embedding_separator = embedding_separator @@ -88,24 +101,19 @@ def _embed_batch( self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None ): """ - Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1. - If this changes in the future, line 86 (the first line within the for loop), can contain: - batch = texts_to_embed[i + i + batch_size] + Internal method to embed a batch of texts. 
""" all_embeddings = [] - meta: Dict[str, Any] = {"model": ""} for i in tqdm( range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): - batch = texts_to_embed[i] # Single batch only - result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs).model_dump() - all_embeddings.append(result["embedding"]) + batch = texts_to_embed[i : i + batch_size] + result = self._client.embed(model=self.model, input=batch, options=generation_kwargs) + all_embeddings.extend(result["embeddings"]) - meta["model"] = self.model - - return all_embeddings, meta + return all_embeddings @component.output_types(documents=List[Document], meta=Dict[str, Any]) def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None): @@ -129,12 +137,14 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A ) raise TypeError(msg) + generation_kwargs = generation_kwargs or self.generation_kwargs + texts_to_embed = self._prepare_texts_to_embed(documents=documents) - embeddings, meta = self._embed_batch( + embeddings = self._embed_batch( texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs ) for doc, emb in zip(documents, embeddings): doc.embedding = emb - return {"documents": documents, "meta": meta} + return {"documents": documents, "meta": {"model": self.model}} diff --git a/integrations/ollama/tests/test_document_embedder.py b/integrations/ollama/tests/test_document_embedder.py index 4fe3cfbb3..7d972e898 100644 --- a/integrations/ollama/tests/test_document_embedder.py +++ b/integrations/ollama/tests/test_document_embedder.py @@ -43,10 +43,14 @@ def import_text_in_embedder(self): @pytest.mark.integration def test_run(self): - embedder = OllamaDocumentEmbedder(model="nomic-embed-text") - list_of_docs = [Document(content="This is a document containing some text.")] - reply = embedder.run(list_of_docs) - - assert isinstance(reply, dict) - assert all(isinstance(element, float) for element in reply["documents"][0].embedding) - assert reply["meta"]["model"] == "nomic-embed-text" + embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2) + list_of_docs = [ + Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."), + Document(content="The Andes mountains are the natural habitat of many llamas."), + Document(content="Llamas have been used as pack animals for centuries, especially in South America."), + ] + result = embedder.run(list_of_docs) + assert result["meta"]["model"] == "nomic-embed-text" + documents = result["documents"] + assert len(documents) == 3 + assert all(isinstance(element, float) for document in documents for element in document.embedding) From 319b64b4dd0be883580fc0956f488f663c6b99a6 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 28 Nov 2024 16:39:03 +0000 Subject: [PATCH 120/229] Update the changelog --- integrations/ollama/CHANGELOG.md | 68 ++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 8 deletions(-) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 29c8dd910..9e2e0a0cb 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,27 +1,45 @@ # Changelog +## [integrations/ollama-v2.1.0] - 2024-11-28 + +### 🚀 Features + +- `OllamaDocumentEmbedder` - allow batching embeddings (#1224) + +### 🌀 Miscellaneous + +- Chore: update changelog for `ollama-haystack==2.0.0` (#1214) +- Chore: use class 
methods to create `ChatMessage` (#1222) + ## [integrations/ollama-v2.0.0] - 2024-11-22 ### 🐛 Bug Fixes - Adapt to Ollama client 0.4.0 (#1209) +### ⚙️ CI + +- Adopt uv as installer (#1142) + + ## [integrations/ollama-v1.1.0] - 2024-10-11 ### 🚀 Features - Add `keep_alive` parameter to Ollama Generators (#1131) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Update ruff linting scripts and settings (#1105) + ## [integrations/ollama-v1.0.1] - 2024-09-26 ### 🐛 Bug Fixes - Ollama Chat Generator - add missing `to_dict` and `from_dict` methods (#1110) + ## [integrations/ollama-v1.0.0] - 2024-09-07 ### 🐛 Bug Fixes @@ -36,31 +54,47 @@ - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) +### 🌀 Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Chore: ollama - ruff update, don't ruff tests (#985) + ## [integrations/ollama-v0.0.7] - 2024-05-31 ### 🚀 Features - Add streaming support to OllamaChatGenerator (#757) +### 🌀 Miscellaneous + +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) + ## [integrations/ollama-v0.0.6] - 2024-04-18 ### 📚 Documentation - Disable-class-def (#556) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Update docstrings (#499) -### Ollama +### 🌀 Miscellaneous +- Update API docs adding embedders (#494) - Change testing workflow (#551) +- Remove references to Python 3.7 (#601) - Add ollama embedder example (#669) +- Fix: change ollama output name to 'meta' (#670) ## [integrations/ollama-v0.0.5] - 2024-02-28 @@ -68,24 +102,42 @@ - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Use `serialize_callable` instead of `serialize_callback_handler` in Ollama (#461) +### 🌀 Miscellaneous + +- Ollama document embedder (#400) +- Changed Default Ollama Embedding models to supported model: nomic-embed-text (#490) + ## [integrations/ollama-v0.0.4] - 2024-02-12 -### Ollama +### 🌀 Miscellaneous +- Ollama: add license (#219) - Generate api docs (#332) +- Ollama Text Embedder with new format (#252) +- Support for streaming ollama generator (#280) ## [integrations/ollama-v0.0.3] - 2024-01-16 +### 🌀 Miscellaneous + +- Docs: Ollama docstrings update (#171) +- Add example of OllamaGenerator (#170) +- Ollama Chat Generator (#176) +- Ollama: improve test (#191) +- Mount Ollama in haystack_integrations (#216) + ## [integrations/ollama-v0.0.1] - 2024-01-03 +### 🌀 Miscellaneous + +- Ollama Integration (#132) + From 798fb98fddb6b1c26a32c426c65f764efc42b787 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Tue, 3 Dec 2024 16:59:12 +0100 Subject: [PATCH 121/229] fix: Allow passing boto3 config to all AWS Bedrock classes (#1166) * Allow passing boto3 config to AmazonBedrockChatGenerator * Allow passing boto3 config to AmazonBedrockDocumentEmbedder * Allow passing boto3 config to AmazonBedrockTextEmbedder * Remove whitespace from blank line * Reorder setting attributes for readability * Remove blank line * fix: adapt our implementation to breaking changes in Chroma 0.5.17 (#1165) * fix chroma breaking changes * improve warning * better warning * Update the changelog * Parametrize to_dict and from_dict tests with boto3_config --------- Co-authored-by: Stefano Fiorucci Co-authored-by: HaystackBot Co-authored-by: David S. 
Batista --- .../amazon_bedrock/document_embedder.py | 34 +++++++++++------- .../embedders/amazon_bedrock/text_embedder.py | 26 +++++++++----- .../amazon_bedrock/chat/chat_generator.py | 19 +++++++--- .../tests/test_chat_generator.py | 28 +++++++++++++-- .../tests/test_document_embedder.py | 27 ++++++++++++-- .../amazon_bedrock/tests/test_generator.py | 36 ++++++++++++++----- .../tests/test_text_embedder.py | 5 +++ 7 files changed, 135 insertions(+), 40 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py index f2906c00d..f15601f57 100755 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py @@ -2,6 +2,7 @@ import logging from typing import Any, Dict, List, Literal, Optional +from botocore.config import Config from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document @@ -73,6 +74,7 @@ def __init__( progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + boto3_config: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -98,6 +100,7 @@ def __init__( to keep the logs clean. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text. :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + :param boto3_config: The configuration for the boto3 client. :param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for Cohere models. :raises ValueError: If the model is not supported. @@ -110,6 +113,19 @@ def __init__( ) raise ValueError(msg) + self.model = model + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.aws_region_name = aws_region_name + self.aws_profile_name = aws_profile_name + self.batch_size = batch_size + self.progress_bar = progress_bar + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + self.boto3_config = boto3_config + self.kwargs = kwargs + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -121,7 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), ) - self._client = session.client("bedrock-runtime") + config: Optional[Config] = None + if self.boto3_config: + config = Config(**self.boto3_config) + self._client = session.client("bedrock-runtime", config=config) except Exception as exception: msg = ( "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. 
" @@ -129,18 +148,6 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception - self.model = model - self.aws_access_key_id = aws_access_key_id - self.aws_secret_access_key = aws_secret_access_key - self.aws_session_token = aws_session_token - self.aws_region_name = aws_region_name - self.aws_profile_name = aws_profile_name - self.batch_size = batch_size - self.progress_bar = progress_bar - self.meta_fields_to_embed = meta_fields_to_embed or [] - self.embedding_separator = embedding_separator - self.kwargs = kwargs - def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: """ Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. @@ -269,6 +276,7 @@ def to_dict(self) -> Dict[str, Any]: progress_bar=self.progress_bar, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + boto3_config=self.boto3_config, **self.kwargs, ) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py index 0cceda92f..0acd51da5 100755 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py @@ -2,6 +2,7 @@ import logging from typing import Any, Dict, List, Literal, Optional +from botocore.config import Config from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict from haystack.utils.auth import Secret, deserialize_secrets_inplace @@ -62,6 +63,7 @@ def __init__( aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008 aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008 + boto3_config: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -81,6 +83,7 @@ def __init__( :param aws_session_token: AWS session token. :param aws_region_name: AWS region name. :param aws_profile_name: AWS profile name. + :param boto3_config: The configuration for the boto3 client. :param kwargs: Additional parameters to pass for model inference. For example, `input_type` and `truncate` for Cohere models. :raises ValueError: If the model is not supported. 
@@ -92,6 +95,15 @@ def __init__( ) raise ValueError(msg) + self.model = model + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.aws_region_name = aws_region_name + self.aws_profile_name = aws_profile_name + self.boto3_config = boto3_config + self.kwargs = kwargs + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -103,7 +115,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), ) - self._client = session.client("bedrock-runtime") + config: Optional[Config] = None + if self.boto3_config: + config = Config(**self.boto3_config) + self._client = session.client("bedrock-runtime", config=config) except Exception as exception: msg = ( "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " @@ -111,14 +126,6 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception - self.model = model - self.aws_access_key_id = aws_access_key_id - self.aws_secret_access_key = aws_secret_access_key - self.aws_session_token = aws_session_token - self.aws_region_name = aws_region_name - self.aws_profile_name = aws_profile_name - self.kwargs = kwargs - @component.output_types(embedding=List[float]) def run(self, text: str): """Embeds the input text using the Amazon Bedrock model. @@ -185,6 +192,7 @@ def to_dict(self) -> Dict[str, Any]: aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, + boto3_config=self.boto3_config, **self.kwargs, ) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 6bb3cc301..183198bce 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -3,6 +3,7 @@ import re from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from botocore.config import Config from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk @@ -77,6 +78,7 @@ def __init__( stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, truncate: Optional[bool] = True, + boto3_config: Optional[Dict[str, Any]] = None, ): """ Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the @@ -110,6 +112,11 @@ def __init__( [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches the streaming mode on. :param truncate: Whether to truncate the prompt messages or not. + :param boto3_config: The configuration for the boto3 client. + + :raises ValueError: If the model name is empty or None. + :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is + not supported. 
""" if not model: msg = "'model' cannot be None or empty string" @@ -120,7 +127,10 @@ def __init__( self.aws_session_token = aws_session_token self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name + self.stop_words = stop_words or [] + self.streaming_callback = streaming_callback self.truncate = truncate + self.boto3_config = boto3_config # get the model adapter for the given model model_adapter_cls = self.get_model_adapter(model=model) @@ -141,7 +151,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), ) - self.client = session.client("bedrock-runtime") + config: Optional[Config] = None + if self.boto3_config: + config = Config(**self.boto3_config) + self.client = session.client("bedrock-runtime", config=config) except Exception as exception: msg = ( "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " @@ -149,9 +162,6 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception - self.stop_words = stop_words or [] - self.streaming_callback = streaming_callback - @component.output_types(replies=List[ChatMessage]) def run( self, @@ -256,6 +266,7 @@ def to_dict(self) -> Dict[str, Any]: generation_kwargs=self.model_adapter.generation_kwargs, streaming_callback=callback_name, truncate=self.truncate, + boto3_config=self.boto3_config, ) @classmethod diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 22594af2c..8d6a5c3ee 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import Optional, Type +from typing import Any, Dict, Optional, Type from unittest.mock import MagicMock, patch import pytest @@ -26,7 +26,16 @@ ] -def test_to_dict(mock_boto3_session): +@pytest.mark.parametrize( + "boto3_config", + [ + None, + { + "read_timeout": 1000, + }, + ], +) +def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]): """ Test that the to_dict method returns the correct dictionary without aws credentials """ @@ -34,6 +43,7 @@ def test_to_dict(mock_boto3_session): model="anthropic.claude-v2", generation_kwargs={"temperature": 0.7}, streaming_callback=print_streaming_chunk, + boto3_config=boto3_config, ) expected_dict = { "type": KLASS, @@ -48,13 +58,23 @@ def test_to_dict(mock_boto3_session): "stop_words": [], "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "truncate": True, + "boto3_config": boto3_config, }, } assert generator.to_dict() == expected_dict -def test_from_dict(mock_boto3_session): +@pytest.mark.parametrize( + "boto3_config", + [ + None, + { + "read_timeout": 1000, + }, + ], +) +def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]): """ Test that the from_dict method returns the correct object """ @@ -71,12 +91,14 @@ def test_from_dict(mock_boto3_session): "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "truncate": True, + "boto3_config": boto3_config, }, } ) assert generator.model == "anthropic.claude-v2" assert generator.model_adapter.generation_kwargs == {"temperature": 0.7} assert generator.streaming_callback == print_streaming_chunk + assert 
generator.boto3_config == boto3_config def test_default_constructor(mock_boto3_session, set_env_variables): diff --git a/integrations/amazon_bedrock/tests/test_document_embedder.py b/integrations/amazon_bedrock/tests/test_document_embedder.py index 9856c97bb..05672e9c7 100644 --- a/integrations/amazon_bedrock/tests/test_document_embedder.py +++ b/integrations/amazon_bedrock/tests/test_document_embedder.py @@ -1,4 +1,5 @@ import io +from typing import Any, Dict, Optional from unittest.mock import patch import pytest @@ -66,10 +67,20 @@ def test_connection_error(self, mock_boto3_session): input_type="fake_input_type", ) - def test_to_dict(self, mock_boto3_session): + @pytest.mark.parametrize( + "boto3_config", + [ + None, + { + "read_timeout": 1000, + }, + ], + ) + def test_to_dict(self, mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]): embedder = AmazonBedrockDocumentEmbedder( model="cohere.embed-english-v3", input_type="search_document", + boto3_config=boto3_config, ) expected_dict = { @@ -86,12 +97,22 @@ def test_to_dict(self, mock_boto3_session): "progress_bar": True, "meta_fields_to_embed": [], "embedding_separator": "\n", + "boto3_config": boto3_config, }, } assert embedder.to_dict() == expected_dict - def test_from_dict(self, mock_boto3_session): + @pytest.mark.parametrize( + "boto3_config", + [ + None, + { + "read_timeout": 1000, + }, + ], + ) + def test_from_dict(self, mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]): data = { "type": TYPE, "init_parameters": { @@ -106,6 +127,7 @@ def test_from_dict(self, mock_boto3_session): "progress_bar": True, "meta_fields_to_embed": [], "embedding_separator": "\n", + "boto3_config": boto3_config, }, } @@ -117,6 +139,7 @@ def test_from_dict(self, mock_boto3_session): assert embedder.progress_bar assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" + assert embedder.boto3_config == boto3_config def test_init_invalid_model(self): with pytest.raises(ValueError): diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index be645218e..54b185da5 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Any, Dict, Optional, Type from unittest.mock import MagicMock, call, patch import pytest @@ -17,11 +17,22 @@ ) -def test_to_dict(mock_boto3_session): +@pytest.mark.parametrize( + "boto3_config", + [ + None, + { + "read_timeout": 1000, + }, + ], +) +def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]): """ Test that the to_dict method returns the correct dictionary without aws credentials """ - generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10) + generator = AmazonBedrockGenerator( + model="anthropic.claude-v2", max_length=99, truncate=False, temperature=10, boto3_config=boto3_config + ) expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", @@ -36,14 +47,23 @@ def test_to_dict(mock_boto3_session): "truncate": False, "temperature": 10, "streaming_callback": None, - "boto3_config": None, + "boto3_config": boto3_config, }, } assert generator.to_dict() == expected_dict -def test_from_dict(mock_boto3_session): +@pytest.mark.parametrize( + "boto3_config", + [ + None, + { + "read_timeout": 1000, + }, + ], +) +def test_from_dict(mock_boto3_session: 
Any, boto3_config: Optional[Dict[str, Any]]): """ Test that the from_dict method returns the correct object """ @@ -58,16 +78,14 @@ def test_from_dict(mock_boto3_session): "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, - "boto3_config": { - "read_timeout": 1000, - }, + "boto3_config": boto3_config, }, } ) assert generator.max_length == 99 assert generator.model == "anthropic.claude-v2" - assert generator.boto3_config == {"read_timeout": 1000} + assert generator.boto3_config == boto3_config def test_default_constructor(mock_boto3_session, set_env_variables): diff --git a/integrations/amazon_bedrock/tests/test_text_embedder.py b/integrations/amazon_bedrock/tests/test_text_embedder.py index 4f4e92448..2518b5c5f 100644 --- a/integrations/amazon_bedrock/tests/test_text_embedder.py +++ b/integrations/amazon_bedrock/tests/test_text_embedder.py @@ -59,6 +59,7 @@ def test_to_dict(self, mock_boto3_session): "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "cohere.embed-english-v3", "input_type": "search_query", + "boto3_config": None, }, } @@ -76,6 +77,9 @@ def test_from_dict(self, mock_boto3_session): "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "cohere.embed-english-v3", "input_type": "search_query", + "boto3_config": { + "read_timeout": 1000, + }, }, } @@ -83,6 +87,7 @@ def test_from_dict(self, mock_boto3_session): assert embedder.model == "cohere.embed-english-v3" assert embedder.kwargs == {"input_type": "search_query"} + assert embedder.boto3_config == {"read_timeout": 1000} def test_init_invalid_model(self): with pytest.raises(ValueError): From 2c80a0b326e32064970492e60a5d0b550be62ad2 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 3 Dec 2024 16:06:59 +0000 Subject: [PATCH 122/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 78 +++++++++++++++++++++--- 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 1068e870a..8e4350423 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,27 +1,45 @@ # Changelog +## [integrations/amazon_bedrock-v1.1.1] - 2024-12-03 + +### 🐛 Bug Fixes + +- AmazonBedrockChatGenerator with Claude raises moot warning for stream… (#1205) +- Allow passing boto3 config to all AWS Bedrock classes (#1166) + +### 🧹 Chores + +- Fix linting/isort (#1215) + +### 🌀 Miscellaneous + +- Chore: use class methods to create `ChatMessage` (#1222) + ## [integrations/amazon_bedrock-v1.1.0] - 2024-10-23 ### 🚜 Refactor - Avoid downloading tokenizer if `truncate` is `False` (#1152) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Adopt uv as installer (#1142) + ## [integrations/amazon_bedrock-v1.0.5] - 2024-10-17 ### 🚀 Features - Add prefixes to supported model patterns to allow cross region model ids (#1127) + ## [integrations/amazon_bedrock-v1.0.4] - 2024-10-16 ### 🐛 Bug Fixes - Avoid bedrock read timeout (add boto3_config param) (#1135) + ## [integrations/amazon_bedrock-v1.0.3] - 2024-10-04 ### 🐛 Bug Fixes @@ -33,10 +51,14 @@ - Remove usage of deprecated `ChatMessage.to_openai_format` (#1007) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Update ruff linting scripts and settings (#1105) +### 🌀 Miscellaneous + +- Modify regex to allow cross-region inference in bedrock (#1120) + ## [integrations/amazon_bedrock-v1.0.1] - 2024-08-19 ### 🚀 
Features @@ -47,6 +69,7 @@ - Normalising ChatGenerators output (#973) + ## [integrations/amazon_bedrock-v1.0.0] - 2024-08-12 ### 🚜 Refactor @@ -57,13 +80,14 @@ - Do not retry tests in `hatch run test` command (#954) + ## [integrations/amazon_bedrock-v0.10.0] - 2024-08-12 ### 🐛 Bug Fixes - Support streaming_callback param in amazon bedrock generators (#927) -### Docs +### 🌀 Miscellaneous - Update AmazonBedrockChatGenerator docstrings (#949) - Update AmazonBedrockGenerator docstrings (#956) @@ -75,11 +99,19 @@ - Use non-gated tokenizer as fallback for mistral in AmazonBedrockChatGenerator (#843) - Made truncation optional for BedrockGenerator (#833) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) +### 🌀 Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Add meta deprecration warning (#910) + ## [integrations/amazon_bedrock-v0.9.0] - 2024-06-14 ### 🚀 Features @@ -96,8 +128,18 @@ - Max_tokens typo in Mistral Chat (#740) +### 🌀 Miscellaneous + +- Chore: change the pydoc renderer class (#718) +- Adding support of "amazon.titan-embed-text-v2:0" (#735) + ## [integrations/amazon_bedrock-v0.7.1] - 2024-04-24 +### 🌀 Miscellaneous + +- Chore: add license classifiers (#680) +- Fix: Fix streaming_callback serialization in AmazonBedrockChatGenerator (#685) + ## [integrations/amazon_bedrock-v0.7.0] - 2024-04-16 ### 🚀 Features @@ -108,6 +150,11 @@ - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Remove references to Python 3.7 (#601) +- [Bedrock] Added Amazon Bedrock examples (#635) + ## [integrations/amazon_bedrock-v0.6.0] - 2024-03-11 ### 🚀 Features @@ -119,6 +166,10 @@ - Small consistency improvements (#536) - Review integrations bedrock (#550) +### 🌀 Miscellaneous + +- Docs updates + two additional unit tests (#513) + ## [integrations/amazon_bedrock-v0.5.1] - 2024-02-22 ### 🚀 Features @@ -129,20 +180,27 @@ - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Update Amazon Bedrock integration to use new generic callable (de)serializers for their callback handlers (#452) - Use `serialize_callable` instead of `serialize_callback_handler` in Bedrock (#459) +### 🌀 Miscellaneous + +- Amazon bedrock: generate api docs (#326) +- Adopt Secret to Amazon Bedrock (#416) +- Bedrock - remove `supports` method (#456) +- Bedrock refactoring (#455) +- Bedrock Text Embedder (#466) +- Bedrock Document Embedder (#468) + ## [integrations/amazon_bedrock-v0.3.0] - 2024-01-30 -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - [**breaking**] Rename `model_name` to `model` in `AmazonBedrockGenerator` (#220) - Amazon Bedrock subproject refactoring (#293) @@ -150,4 +208,8 @@ This PR will also push the docs to Readme ## [integrations/amazon_bedrock-v0.1.0] - 2024-01-03 +### 🌀 Miscellaneous + +- [Amazon Bedrock] Add AmazonBedrockGenerator (#153) + From 1959ab16cd50f3cd922e8ad83dddb3a7662d8722 Mon Sep 17 00:00:00 2001 From: Michele Pangrazzi Date: Thu, 5 Dec 2024 14:20:39 +0100 Subject: [PATCH 123/229] Mongodb keyword search (#1228) * feat: add full-text search capability * feat: add full-text retriever * docs: update docs for mongodb atlas indexes * docs: update usage example for document store * feat: update embedding retrieval example * feat: add hybrid retrieval example * fix: correct typo for parameter name * test: add full-text retrieval test * 
test: add test for full-text aggregation pipeline * tested examples ; minor refactor adding prints * fix lint * fix test * update test * fix lint * fix fulltext_retriever tests * update workflow to set MONGO_CONNECTION_STRING_2 in env --------- Co-authored-by: kanenorman Co-authored-by: Kane Norman <51185594+kanenorman@users.noreply.github.com> Co-authored-by: Vladimir Blagojevic --- .github/workflows/mongodb_atlas.yml | 13 +- .../{example.py => embedding_retrieval.py} | 16 +- .../examples/hybrid_retrieval.py | 80 +++++++ .../retrievers/mongodb_atlas/__init__.py | 3 +- .../mongodb_atlas/embedding_retriever.py | 3 +- .../mongodb_atlas/full_text_retriever.py | 150 ++++++++++++ .../mongodb_atlas/document_store.py | 118 +++++++++- .../tests/test_document_store.py | 5 + .../tests/test_embedding_retrieval.py | 8 +- .../tests/test_fulltext_retrieval.py | 147 ++++++++++++ .../mongodb_atlas/tests/test_retriever.py | 219 +++++++++++++++++- 11 files changed, 744 insertions(+), 18 deletions(-) rename integrations/mongodb_atlas/examples/{example.py => embedding_retrieval.py} (80%) create mode 100644 integrations/mongodb_atlas/examples/hybrid_retrieval.py create mode 100644 integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/full_text_retriever.py create mode 100644 integrations/mongodb_atlas/tests/test_fulltext_retrieval.py diff --git a/.github/workflows/mongodb_atlas.yml b/.github/workflows/mongodb_atlas.yml index 3d1ad5101..3fd2a43ac 100644 --- a/.github/workflows/mongodb_atlas.yml +++ b/.github/workflows/mongodb_atlas.yml @@ -4,11 +4,11 @@ name: Test / mongodb_atlas on: schedule: - - cron: "0 0 * * *" + - cron: '0 0 * * *' pull_request: paths: - - "integrations/mongodb_atlas/**" - - ".github/workflows/mongodb_atlas.yml" + - 'integrations/mongodb_atlas/**' + - '.github/workflows/mongodb_atlas.yml' defaults: run: @@ -19,9 +19,10 @@ concurrency: cancel-in-progress: true env: - PYTHONUNBUFFERED: "1" - FORCE_COLOR: "1" + PYTHONUNBUFFERED: '1' + FORCE_COLOR: '1' MONGO_CONNECTION_STRING: ${{ secrets.MONGO_CONNECTION_STRING }} + MONGO_CONNECTION_STRING_2: ${{ secrets.MONGO_CONNECTION_STRING_2 }} jobs: run: @@ -31,7 +32,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.9", "3.10", "3.11"] + python-version: ['3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v4 diff --git a/integrations/mongodb_atlas/examples/example.py b/integrations/mongodb_atlas/examples/embedding_retrieval.py similarity index 80% rename from integrations/mongodb_atlas/examples/example.py rename to integrations/mongodb_atlas/examples/embedding_retrieval.py index 54fd569ce..d8a71c343 100644 --- a/integrations/mongodb_atlas/examples/example.py +++ b/integrations/mongodb_atlas/examples/embedding_retrieval.py @@ -19,6 +19,8 @@ # To use the MongoDBAtlasDocumentStore, you must have a running MongoDB Atlas database. # For details, see https://www.mongodb.com/docs/atlas/getting-started/ +# NOTE: you need to create manually the vector search index and the full text search +# index in your MongoDB Atlas database. # Once your database is set, set the environment variable `MONGO_CONNECTION_STRING` # with the connection string to your MongoDB Atlas database. 
@@ -29,12 +31,17 @@ database_name="haystack_test", collection_name="test_collection", vector_search_index="test_vector_search_index", + full_text_search_index="test_full_text_search_index", ) +# This is to avoid duplicates in the collection +print(f"Cleaning up collection {document_store.collection_name}") +document_store.collection.delete_many({}) + # Create the indexing Pipeline and index some documents file_paths = glob.glob("neural-search-pills/pills/*.md") - +print("Creating indexing pipeline") indexing = Pipeline() indexing.add_component("converter", MarkdownToDocument()) indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) @@ -44,17 +51,20 @@ indexing.connect("splitter", "embedder") indexing.connect("embedder", "writer") +print(f"Running indexing pipeline with {len(file_paths)} files") indexing.run({"converter": {"sources": file_paths}}) - -# Create the querying Pipeline and try a query +print("Creating querying pipeline") querying = Pipeline() querying.add_component("embedder", SentenceTransformersTextEmbedder()) querying.add_component("retriever", MongoDBAtlasEmbeddingRetriever(document_store=document_store, top_k=3)) querying.connect("embedder", "retriever") +query = "What is a cross-encoder?" +print(f"Running querying pipeline with query: '{query}'") results = querying.run({"embedder": {"text": "What is a cross-encoder?"}}) +print(f"Results: {results}") for doc in results["retriever"]["documents"]: print(doc) print("-" * 10) diff --git a/integrations/mongodb_atlas/examples/hybrid_retrieval.py b/integrations/mongodb_atlas/examples/hybrid_retrieval.py new file mode 100644 index 000000000..a165edf12 --- /dev/null +++ b/integrations/mongodb_atlas/examples/hybrid_retrieval.py @@ -0,0 +1,80 @@ +# Install required packages for this example, including mongodb-atlas-haystack and other libraries needed +# for Markdown conversion and embeddings generation. Use the following command: +# +# pip install mongodb-atlas-haystack markdown-it-py mdit_plain "sentence-transformers>=2.2.0" +# +# Download some Markdown files to index. +# git clone https://github.com/anakin87/neural-search-pills + +import glob + +from haystack import Pipeline +from haystack.components.converters import MarkdownToDocument +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.joiners import DocumentJoiner +from haystack.components.preprocessors import DocumentSplitter +from haystack.components.writers import DocumentWriter + +from haystack_integrations.components.retrievers.mongodb_atlas import ( + MongoDBAtlasEmbeddingRetriever, + MongoDBAtlasFullTextRetriever, +) +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + +# To use the MongoDBAtlasDocumentStore, you must have a running MongoDB Atlas database. +# For details, see https://www.mongodb.com/docs/atlas/getting-started/ +# NOTE: you need to create manually the vector search index and the full text search +# index in your MongoDB Atlas database. + +# Once your database is set, set the environment variable `MONGO_CONNECTION_STRING` +# with the connection string to your MongoDB Atlas database. +# format: "mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}". 
+ +# Initialize the document store +document_store = MongoDBAtlasDocumentStore( + database_name="haystack_test", + collection_name="test_collection", + vector_search_index="test_vector_search_index", + full_text_search_index="test_full_text_search_index", +) + +file_paths = glob.glob("neural-search-pills/pills/*.md") + +# This is to avoid duplicates in the collection +print(f"Cleaning up collection {document_store.collection_name}") +document_store.collection.delete_many({}) + +print("Creating indexing pipeline") +indexing = Pipeline() +indexing.add_component("converter", MarkdownToDocument()) +indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) +indexing.add_component("document_embedder", SentenceTransformersDocumentEmbedder()) +indexing.add_component("writer", DocumentWriter(document_store)) +indexing.connect("converter", "splitter") +indexing.connect("splitter", "document_embedder") +indexing.connect("document_embedder", "writer") + +print(f"Running indexing pipeline with {len(file_paths)} files") +indexing.run({"converter": {"sources": file_paths}}) + +print("Creating querying pipeline") +querying = Pipeline() +querying.add_component("text_embedder", SentenceTransformersTextEmbedder()) +querying.add_component("embedding_retriever", MongoDBAtlasEmbeddingRetriever(document_store=document_store, top_k=3)) +querying.add_component("full_text_retriever", MongoDBAtlasFullTextRetriever(document_store=document_store, top_k=3)) +querying.add_component( + "joiner", + DocumentJoiner(join_mode="reciprocal_rank_fusion", top_k=3), +) +querying.connect("text_embedder", "embedding_retriever") +querying.connect("embedding_retriever", "joiner") +querying.connect("full_text_retriever", "joiner") + +query = "cross-encoder" +print(f"Running querying pipeline with query '{query}'") +results = querying.run({"text_embedder": {"text": query}, "full_text_retriever": {"query": query}}) + +print(f"Results: {results}") +for doc in results["joiner"]["documents"]: + print(doc) + print("-" * 10) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py index fed0a4c28..bbeec63d1 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py @@ -1,3 +1,4 @@ from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever +from haystack_integrations.components.retrievers.mongodb_atlas.full_text_retriever import MongoDBAtlasFullTextRetriever -__all__ = ["MongoDBAtlasEmbeddingRetriever"] +__all__ = ["MongoDBAtlasEmbeddingRetriever", "MongoDBAtlasFullTextRetriever"] diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 3345f4f7c..4579a85bc 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -28,7 +28,8 @@ class MongoDBAtlasEmbeddingRetriever: store = MongoDBAtlasDocumentStore(database_name="haystack_integration_test", 
collection_name="test_embeddings_collection", - vector_search_index="cosine_index") + vector_search_index="cosine_index", + full_text_search_index="full_text_index") retriever = MongoDBAtlasEmbeddingRetriever(document_store=store) results = retriever.run(query_embedding=np.random.random(768).tolist()) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/full_text_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/full_text_retriever.py new file mode 100644 index 000000000..63348c6f3 --- /dev/null +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/full_text_retriever.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Literal, Optional, Union + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +@component +class MongoDBAtlasFullTextRetriever: + """ + Retrieves documents from the MongoDBAtlasDocumentStore by full-text search. + + The full-text search is dependent on the full_text_search_index used in the MongoDBAtlasDocumentStore. + See MongoDBAtlasDocumentStore for more information. + + Usage example: + ```python + from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasFullTextRetriever + + store = MongoDBAtlasDocumentStore(database_name="your_existing_db", + collection_name="your_existing_collection", + vector_search_index="your_existing_index", + full_text_search_index="your_existing_index") + retriever = MongoDBAtlasFullTextRetriever(document_store=store) + + results = retriever.run(query="Lorem ipsum") + print(results["documents"]) + ``` + + The example above retrieves the 10 most similar documents to the query "Lorem ipsum" from the + MongoDBAtlasDocumentStore. + """ + + def __init__( + self, + *, + document_store: MongoDBAtlasDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + :param document_store: An instance of MongoDBAtlasDocumentStore. + :param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are + included in the configuration of the `full_text_search_index`. The configuration must be done manually + in the Web UI of MongoDB Atlas. + :param top_k: Maximum number of Documents to return. + :param filter_policy: Policy to determine how filters are applied. + + :raises ValueError: If `document_store` is not an instance of MongoDBAtlasDocumentStore. + """ + + if not isinstance(document_store, MongoDBAtlasDocumentStore): + msg = "document_store must be an instance of MongoDBAtlasDocumentStore" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. 
+ """ + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + filter_policy=self.filter_policy.value, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasFullTextRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: Union[str, List[str]], + fuzzy: Optional[Dict[str, int]] = None, + match_criteria: Optional[Literal["any", "all"]] = None, + score: Optional[Dict[str, Dict]] = None, + synonyms: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ) -> Dict[str, List[Document]]: + """ + Retrieve documents from the MongoDBAtlasDocumentStore by full-text search. + + :param query: The query string or a list of query strings to search for. + If the query contains multiple terms, Atlas Search evaluates each term separately for matches. + :param fuzzy: Enables finding strings similar to the search term(s). + Note, `fuzzy` cannot be used with `synonyms`. Configurable options include `maxEdits`, `prefixLength`, + and `maxExpansions`. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param match_criteria: Defines how terms in the query are matched. Supported options are `"any"` and `"all"`. + For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param score: Specifies the scoring method for matching results. Supported options include `boost`, `constant`, + and `function`. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param synonyms: The name of the synonym mapping definition in the index. This value cannot be an empty string. + Note, `synonyms` can not be used with `fuzzy`. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization. 
+ :returns: A dictionary with the following keys: + - `documents`: List of Documents most similar to the given `query` + """ + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + top_k = top_k or self.top_k + + docs = self.document_store._fulltext_retrieval( + query=query, + fuzzy=fuzzy, + match_criteria=match_criteria, + score=score, + synonyms=synonyms, + filters=filters, + top_k=top_k, + ) + + return {"documents": docs} diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 79caa15f8..f13924185 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import Document @@ -37,8 +37,10 @@ class MongoDBAtlasDocumentStore: Python driver. Creating databases and collections is beyond the scope of MongoDBAtlasDocumentStore. The primary purpose of this document store is to read and write documents to an existing collection. - The last parameter users needs to provide is a `vector_search_index` - used for vector search operations. This index - can support a chosen metric (i.e. cosine, dot product, or euclidean) and can be created in the Atlas web UI. + Users must provide both a `vector_search_index` for vector search operations and a `full_text_search_index` + for full-text search operations. The `vector_search_index` supports a chosen metric + (e.g., cosine, dot product, or Euclidean), while the `full_text_search_index` enables efficient text-based searches. + Both indexes can be created through the Atlas web UI. For more details on MongoDB Atlas, see the official MongoDB Atlas [documentation](https://www.mongodb.com/docs/atlas/getting-started/). @@ -49,7 +51,8 @@ class MongoDBAtlasDocumentStore: store = MongoDBAtlasDocumentStore(database_name="your_existing_db", collection_name="your_existing_collection", - vector_search_index="your_existing_index") + vector_search_index="your_existing_index", + full_text_search_index="your_existing_index") print(store.count_documents()) ``` """ @@ -61,6 +64,7 @@ def __init__( database_name: str, collection_name: str, vector_search_index: str, + full_text_search_index: str, ): """ Creates a new MongoDBAtlasDocumentStore instance. @@ -76,6 +80,10 @@ def __init__( Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \ For more details refer to MongoDB Atlas [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index). + :param full_text_search_index: The name of the search index to use for full-text search operations. + Create a full_text_search_index in the Atlas web UI and specify the init params of + MongoDBAtlasDocumentStore. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/create-index/). :raises ValueError: If the collection name contains invalid characters. 
""" @@ -88,6 +96,7 @@ def __init__( self.database_name = database_name self.collection_name = collection_name self.vector_search_index = vector_search_index + self.full_text_search_index = full_text_search_index self._connection: Optional[MongoClient] = None self._collection: Optional[Collection] = None @@ -124,6 +133,7 @@ def to_dict(self) -> Dict[str, Any]: database_name=self.database_name, collection_name=self.collection_name, vector_search_index=self.vector_search_index, + full_text_search_index=self.full_text_search_index, ) @classmethod @@ -285,6 +295,106 @@ def _embedding_retrieval( documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents] return documents + def _fulltext_retrieval( + self, + query: Union[str, List[str]], + fuzzy: Optional[Dict[str, int]] = None, + match_criteria: Optional[Literal["any", "all"]] = None, + score: Optional[Dict[str, Dict]] = None, + synonyms: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ) -> List[Document]: + """ + Retrieve documents similar to the provided `query` using a full-text search. + + :param query: The query string or a list of query strings to search for. + If the query contains multiple terms, Atlas Search evaluates each term separately for matches. + :param fuzzy: Enables finding strings similar to the search term(s). + Note, `fuzzy` cannot be used with `synonyms`. Configurable options include `maxEdits`, `prefixLength`, + and `maxExpansions`. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param match_criteria: Defines how terms in the query are matched. Supported options are `"any"` and `"all"`. + For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param score: Specifies the scoring method for matching results. Supported options include `boost`, `constant`, + and `function`. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param synonyms: The name of the synonym mapping definition in the index. This value cannot be an empty string. + Note, `synonyms` can not be used with `fuzzy`. + :param filters: Optional filters. + :param top_k: How many documents to return. + :returns: A list of Documents that are most similar to the given `query` + :raises ValueError: If `query` or `synonyms` is empty. + :raises ValueError: If `synonyms` and `fuzzy` are used together. + :raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails. + """ + # Validate user input according to MongoDB Atlas Search requirements + if not query: + msg = "Argument query must not be empty." + raise ValueError(msg) + + if isinstance(synonyms, str) and not synonyms: + msg = "Argument synonyms cannot be an empty string." + raise ValueError(msg) + + if synonyms and fuzzy: + msg = "Cannot use both synonyms and fuzzy search together." + raise ValueError(msg) + + if synonyms and not match_criteria: + logger.warning( + "Specify matchCriteria when using synonyms. " + "Atlas Search matches terms in exact order by default, which may change in future versions." 
+ ) + + filters = _normalize_filters(filters) if filters else {} + + # Build the text search options + text_search: Dict[str, Any] = {"path": "content", "query": query} + if match_criteria: + text_search["matchCriteria"] = match_criteria + if synonyms: + text_search["synonyms"] = synonyms + if fuzzy: + text_search["fuzzy"] = fuzzy + if score: + text_search["score"] = score + + # Define the pipeline for MongoDB aggregation + pipeline = [ + { + "$search": { + "index": self.full_text_search_index, + "compound": {"must": [{"text": text_search}]}, + } + }, + # TODO: Use compound filter. See: (https://www.mongodb.com/docs/atlas/atlas-search/performance/query-performance/#avoid--match-after--search) + {"$match": filters}, + {"$limit": top_k}, + { + "$project": { + "_id": 0, + "content": 1, + "dataframe": 1, + "blob": 1, + "meta": 1, + "embedding": 1, + "score": {"$meta": "searchScore"}, + } + }, + ] + + try: + documents = list(self.collection.aggregate(pipeline)) + except Exception as e: + error_msg = f"Failed to retrieve documents from MongoDB Atlas: {e}" + if filters: + error_msg += "\nEnsure fields in filters are included in the `full_text_search_index` configuration." + raise DocumentStoreError(error_msg) from e + + return [self._mongo_doc_to_haystack_doc(doc) for doc in documents] + def _mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: """ Converts the dictionary coming out of MongoDB into a Haystack document diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 6d34b1ca0..6c0ac191e 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -25,6 +25,7 @@ def test_init_is_lazy(_mock_client): database_name="database_name", collection_name="collection_name", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) _mock_client.assert_not_called() @@ -53,6 +54,7 @@ def document_store(self): database_name=database_name, collection_name=collection_name, vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) yield store database[collection_name].drop() @@ -92,6 +94,7 @@ def test_to_dict(self, document_store): }, "database_name": "haystack_integration_test", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, } @@ -110,6 +113,7 @@ def test_from_dict(self): "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, } ) @@ -117,6 +121,7 @@ def test_from_dict(self): assert docstore.database_name == "haystack_integration_test" assert docstore.collection_name == "test_embeddings_collection" assert docstore.vector_search_index == "cosine_index" + assert docstore.full_text_search_index == "full_text_index" def test_complex_filter(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index 143f6e106..306e59a98 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -21,11 +21,12 @@ def test_embedding_retrieval_cosine_similarity(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + 
full_text_search_index="full_text_index", ) query_embedding = [0.1] * 768 results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 - assert results[0].content == "Document A" + assert results[0].content == "Document C" assert results[1].content == "Document B" assert results[0].score > results[1].score @@ -34,6 +35,7 @@ def test_embedding_retrieval_dot_product(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="dotProduct_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 768 results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) @@ -47,6 +49,7 @@ def test_embedding_retrieval_euclidean(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="euclidean_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 768 results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) @@ -60,6 +63,7 @@ def test_empty_query_embedding(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) query_embedding: List[float] = [] with pytest.raises(ValueError): @@ -70,6 +74,7 @@ def test_query_embedding_wrong_dimension(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 4 with pytest.raises(DocumentStoreError): @@ -98,6 +103,7 @@ def test_embedding_retrieval_with_filters(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 768 filters = {"field": "content", "operator": "!=", "value": "Document A"} diff --git a/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py new file mode 100644 index 000000000..aa0132f2c --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from time import sleep +from typing import List, Union +from unittest.mock import MagicMock + +import pytest +from haystack import Document +from haystack.utils import Secret + +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +def get_document_store(): + return MongoDBAtlasDocumentStore( + mongo_connection_string=Secret.from_env_var("MONGO_CONNECTION_STRING_2"), + database_name="haystack_test", + collection_name="test_collection", + vector_search_index="cosine_index", + full_text_search_index="full_text_index", + ) + + +@pytest.mark.skipif( + "MONGO_CONNECTION_STRING_2" not in os.environ, + reason="No MongoDB Atlas connection string provided", +) +@pytest.mark.integration +class TestFullTextRetrieval: + @pytest.fixture(scope="class") + def document_store(self) -> MongoDBAtlasDocumentStore: + return get_document_store() + + @pytest.fixture(autouse=True, scope="class") + def setup_teardown(self, document_store): + document_store.collection.delete_many({}) + document_store.write_documents( + [ + Document(content="The quick brown fox chased the dog", 
meta={"meta_field": "right_value"}), + Document(content="The fox was brown", meta={"meta_field": "right_value"}), + Document(content="The lazy dog"), + Document(content="fox fox fox"), + ] + ) + + # Wait for documents to be indexed + sleep(5) + + yield + + def test_pipeline_correctly_passes_parameters(self): + document_store = get_document_store() + mock_collection = MagicMock() + document_store._collection = mock_collection + mock_collection.aggregate.return_value = [] + document_store._fulltext_retrieval( + query=["spam", "eggs"], + fuzzy={"maxEdits": 1}, + match_criteria="any", + score={"boost": {"value": 3}}, + filters={"field": "meta.meta_field", "operator": "==", "value": "right_value"}, + top_k=5, + ) + + # Assert aggregate was called with the correct pipeline + assert mock_collection.aggregate.called + actual_pipeline = mock_collection.aggregate.call_args[0][0] + expected_pipeline = [ + { + "$search": { + "compound": { + "must": [ + { + "text": { + "fuzzy": {"maxEdits": 1}, + "matchCriteria": "any", + "path": "content", + "query": ["spam", "eggs"], + "score": {"boost": {"value": 3}}, + } + } + ] + }, + "index": "full_text_index", + } + }, + {"$match": {"meta.meta_field": {"$eq": "right_value"}}}, + {"$limit": 5}, + { + "$project": { + "_id": 0, + "blob": 1, + "content": 1, + "dataframe": 1, + "embedding": 1, + "meta": 1, + "score": {"$meta": "searchScore"}, + } + }, + ] + + assert actual_pipeline == expected_pipeline + + def test_query_retrieval(self, document_store: MongoDBAtlasDocumentStore): + results = document_store._fulltext_retrieval(query="fox", top_k=2) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert results[0].score >= results[1].score + + def test_fuzzy_retrieval(self, document_store: MongoDBAtlasDocumentStore): + results = document_store._fulltext_retrieval(query="fax", fuzzy={"maxEdits": 1}, top_k=2) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert results[0].score >= results[1].score + + def test_filters_retrieval(self, document_store: MongoDBAtlasDocumentStore): + filters = {"field": "meta.meta_field", "operator": "==", "value": "right_value"} + + results = document_store._fulltext_retrieval(query="fox", top_k=3, filters=filters) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert doc.meta["meta_field"] == "right_value" + + def test_synonyms_retrieval(self, document_store: MongoDBAtlasDocumentStore): + results = document_store._fulltext_retrieval(query="reynard", synonyms="synonym_mapping", top_k=2) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert results[0].score >= results[1].score + + @pytest.mark.parametrize("query", ["", []]) + def test_empty_query_raises_value_error(self, query: Union[str, List], document_store: MongoDBAtlasDocumentStore): + with pytest.raises(ValueError): + document_store._fulltext_retrieval(query=query) + + def test_empty_synonyms_raises_value_error(self, document_store: MongoDBAtlasDocumentStore): + with pytest.raises(ValueError): + document_store._fulltext_retrieval(query="fox", synonyms="") + + def test_synonyms_and_fuzzy_raises_value_error(self, document_store: MongoDBAtlasDocumentStore): + with pytest.raises(ValueError): + document_store._fulltext_retrieval(query="fox", synonyms="wolf", fuzzy={"maxEdits": 1}) diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index 832256ccd..26079d145 100644 --- 
a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -8,11 +8,14 @@ from haystack.document_stores.types import FilterPolicy from haystack.utils.auth import EnvVarSecret -from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever +from haystack_integrations.components.retrievers.mongodb_atlas import ( + MongoDBAtlasEmbeddingRetriever, + MongoDBAtlasFullTextRetriever, +) from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore -class TestRetriever: +class TestEmbeddingRetriever: @pytest.fixture def mock_client(self): with patch( @@ -72,6 +75,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) @@ -90,6 +94,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, }, "filters": {"field": "value"}, @@ -115,6 +120,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, }, "filters": {"field": "value"}, @@ -131,6 +137,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client assert document_store.database_name == "haystack_integration_test" assert document_store.collection_name == "test_embeddings_collection" assert document_store.vector_search_index == "cosine_index" + assert document_store.full_text_search_index == "full_text_index" assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 assert retriever.filter_policy == FilterPolicy.REPLACE @@ -152,6 +159,7 @@ def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears u "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, }, "filters": {"field": "value"}, @@ -167,6 +175,7 @@ def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears u assert document_store.database_name == "haystack_integration_test" assert document_store.collection_name == "test_embeddings_collection" assert document_store.vector_search_index == "cosine_index" + assert document_store.full_text_search_index == "full_text_index" assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE @@ -204,3 +213,209 @@ def test_run_merge_policy_filter(self): ) assert res == {"documents": [doc]} + + +class TestFullTextRetriever: + @pytest.fixture + def mock_client(self): + with patch( + "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient" + ) as mock_mongo_client: + mock_connection = MagicMock() + mock_database = MagicMock() + mock_collection_names = MagicMock(return_value=["test_full_text_collection"]) + mock_database.list_collection_names = mock_collection_names + mock_connection.__getitem__.return_value 
= mock_database + mock_mongo_client.return_value = mock_connection + yield mock_mongo_client + + def test_init_default(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) + assert retriever.document_store == mock_store + assert retriever.filters == {} + assert retriever.top_k == 10 + assert retriever.filter_policy == FilterPolicy.REPLACE + + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="merge") + assert retriever.filter_policy == FilterPolicy.MERGE + + with pytest.raises(ValueError): + MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="wrong_policy") + + def test_init(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + top_k=5, + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_init_filter_policy_merge(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + top_k=5, + filter_policy=FilterPolicy.MERGE, + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.MERGE + + def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_full_text_collection", + vector_search_index="cosine_index", + full_text_search_index="full_text_index", + ) + + retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.full_text_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_full_text_collection", + "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "filter_policy": "replace", + }, + } + + def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + data = { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.full_text_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": 
["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_full_text_collection", + "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "filter_policy": "replace", + }, + } + + retriever = MongoDBAtlasFullTextRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, MongoDBAtlasDocumentStore) + assert isinstance(document_store.mongo_connection_string, EnvVarSecret) + assert document_store.database_name == "haystack_integration_test" + assert document_store.collection_name == "test_full_text_collection" + assert document_store.vector_search_index == "cosine_index" + assert document_store.full_text_search_index == "full_text_index" + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + data = { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.full_text_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_full_text_collection", + "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + retriever = MongoDBAtlasFullTextRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, MongoDBAtlasDocumentStore) + assert isinstance(document_store.mongo_connection_string, EnvVarSecret) + assert document_store.database_name == "haystack_integration_test" + assert document_store.collection_name == "test_full_text_collection" + assert document_store.vector_search_index == "cosine_index" + assert document_store.full_text_search_index == "full_text_index" + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_run(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + doc = Document(content="Lorem ipsum") + mock_store._fulltext_retrieval.return_value = [doc] + + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) + res = retriever.run(query="Lorem ipsum") + + mock_store._fulltext_retrieval.assert_called_once_with( + query="Lorem ipsum", fuzzy=None, match_criteria=None, score=None, synonyms=None, filters={}, top_k=10 + ) + + assert res == {"documents": [doc]} + + def test_run_merge_policy_filter(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + doc = Document(content="Lorem ipsum") + mock_store._fulltext_retrieval.return_value = [doc] + + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + filter_policy=FilterPolicy.MERGE, + ) + res = retriever.run( + query="Lorem ipsum", filters={"field": "meta.some_field", "operator": "==", "value": "Test"} + ) + # as the both init 
and run filters are filtering the same field, the run filter takes precedence + mock_store._fulltext_retrieval.assert_called_once_with( + query="Lorem ipsum", + fuzzy=None, + match_criteria=None, + score=None, + synonyms=None, + filters={"field": "meta.some_field", "operator": "==", "value": "Test"}, + top_k=10, + ) + + assert res == {"documents": [doc]} From e9e9c04228b85a0dbe0d794eefb97fc145fe5cc5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:10:41 +0100 Subject: [PATCH 124/229] chore(deps): bump readmeio/rdme from 8 to 9 (#1234) Bumps [readmeio/rdme](https://github.com/readmeio/rdme) from 8 to 9. - [Release notes](https://github.com/readmeio/rdme/releases) - [Changelog](https://github.com/readmeio/rdme/blob/next/CHANGELOG.md) - [Commits](https://github.com/readmeio/rdme/compare/v8...v9) --- updated-dependencies: - dependency-name: readmeio/rdme dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/CI_readme_sync.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI_readme_sync.yml b/.github/workflows/CI_readme_sync.yml index c6204a9be..958cc12f6 100644 --- a/.github/workflows/CI_readme_sync.yml +++ b/.github/workflows/CI_readme_sync.yml @@ -81,6 +81,6 @@ jobs: ls tmp - name: Sync API docs with Haystack docs version ${{ matrix.hs-docs-version }} - uses: readmeio/rdme@v8 + uses: readmeio/rdme@v9 with: rdme: docs ${{ steps.pathfinder.outputs.project_path }}/tmp --key=${{ secrets.README_API_KEY }} --version=${{ matrix.hs-docs-version }} From 7a1297b483f93ee0fa669cfcbbdeb06531ee27b9 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 9 Dec 2024 18:38:41 +0100 Subject: [PATCH 125/229] use text instead of content in Cohere and Anthropic (#1237) --- .../anthropic/tests/test_chat_generator.py | 14 +++++++------- .../anthropic/tests/test_vertex_chat_generator.py | 4 ++-- .../generators/cohere/chat/chat_generator.py | 6 +++--- .../components/generators/cohere/generator.py | 2 +- .../cohere/tests/test_cohere_chat_generator.py | 14 +++++++------- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 155cf7950..9a111fc9d 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -188,9 +188,9 @@ def test_default_inference_params(self, chat_messages): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no text" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" @pytest.mark.skipif( @@ -221,9 +221,9 @@ def streaming_callback(chunk: StreamingChunk): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no text" assert 
ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" @pytest.mark.skipif( @@ -255,11 +255,11 @@ def test_tools_use(self): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no text" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price" + assert "get_stock_price" in first_reply.text.lower(), "First reply does not contain get_stock_price" assert first_reply.meta, "First reply has no metadata" - fc_response = json.loads(first_reply.content) + fc_response = json.loads(first_reply.text) assert "name" in fc_response, "First reply does not contain name of the tool" assert "input" in fc_response, "First reply does not contain input of the tool" diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py index a67e801ad..fefb508ac 100644 --- a/integrations/anthropic/tests/test_vertex_chat_generator.py +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -188,9 +188,9 @@ def test_default_inference_params(self, chat_messages): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no text" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" # Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint, diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index e635e291c..3fae30baa 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -136,7 +136,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: role = "User" if message.role == ChatRole.USER else "Chatbot" - chat_message = {"user_name": role, "text": message.content} + chat_message = {"user_name": role, "text": message.text} return chat_message @component.output_types(replies=List[ChatMessage]) @@ -157,7 +157,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, chat_history = [self._message_to_dict(m) for m in messages[:-1]] if self.streaming_callback: response = self.client.chat_stream( - message=messages[-1].content, + message=messages[-1].text, model=self.model, chat_history=chat_history, **generation_kwargs, @@ -190,7 +190,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: 
Optional[Dict[str, ) else: response = self.client.chat( - message=messages[-1].content, + message=messages[-1].text, model=self.model, chat_history=chat_history, **generation_kwargs, diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index e4eaf8670..4630962ba 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -67,4 +67,4 @@ def run(self, prompt: str): chat_message = ChatMessage.from_user(prompt) # Note we have to call super() like this because of the way components are dynamically built with the decorator results = super(CohereGenerator, self).run([chat_message]) # noqa - return {"replies": [results["replies"][0].content], "meta": [results["replies"][0].meta]} + return {"replies": [results["replies"][0].text], "meta": [results["replies"][0].meta]} diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index b7cc0534a..09f3708eb 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -169,7 +169,7 @@ def test_live_run(self): results = component.run(chat_messages) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert "usage" in message.meta assert "prompt_tokens" in message.meta["usage"] assert "completion_tokens" in message.meta["usage"] @@ -205,7 +205,7 @@ def __call__(self, chunk: StreamingChunk) -> None: assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert message.meta["finish_reason"] == "COMPLETE" @@ -227,7 +227,7 @@ def test_live_run_with_connector(self): results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert message.meta["documents"] is not None assert "citations" in message.meta # Citations might be None @@ -253,7 +253,7 @@ def __call__(self, chunk: StreamingChunk) -> None: assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert message.meta["finish_reason"] == "COMPLETE" @@ -291,10 +291,10 @@ def test_tools_use(self): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no text" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "get_stock_price" in first_reply.content.lower(), "First reply does not contain get_stock_price" + assert "get_stock_price" in first_reply.text.lower(), "First reply does not contain get_stock_price" assert first_reply.meta, "First reply has no metadata" - fc_response = json.loads(first_reply.content) + fc_response = json.loads(first_reply.text) assert "name" in fc_response, "First reply does not contain name of the tool" assert "parameters" in fc_response, "First reply does not contain parameters of the tool" 
From e62cf79e1e3e9351e85268992c3d12fdce0a8d24 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 9 Dec 2024 17:41:37 +0000 Subject: [PATCH 126/229] Update the changelog --- integrations/cohere/CHANGELOG.md | 96 ++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 16 deletions(-) diff --git a/integrations/cohere/CHANGELOG.md b/integrations/cohere/CHANGELOG.md index 3f36836cc..1d98408e9 100644 --- a/integrations/cohere/CHANGELOG.md +++ b/integrations/cohere/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## [integrations/cohere-v2.0.1] - 2024-12-09 + +### ⚙️ CI + +- Adopt uv as installer (#1142) + +### 🧹 Chores + +- Update ruff linting scripts and settings (#1105) +- Fix linting/isort (#1215) + +### 🌀 Miscellaneous + +- Chore: use class methods to create `ChatMessage` (#1222) +- Chore: use `text` instead of `content` for `ChatMessage` in Cohere and Anthropic (#1237) + ## [integrations/cohere-v2.0.0] - 2024-09-16 ### 🚀 Features @@ -16,28 +32,49 @@ - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) -### Docs +### 🌀 Miscellaneous +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) - Update CohereChatGenerator docstrings (#958) - Update CohereGenerator docstrings (#960) ## [integrations/cohere-v1.1.1] - 2024-06-12 +### 🌀 Miscellaneous + +- Chore: `CohereGenerator` - remove warning about `generate` API (#805) + ## [integrations/cohere-v1.1.0] - 2024-05-24 ### 🐛 Bug Fixes - Remove support for generate API (#755) +### 🌀 Miscellaneous + +- Chore: change the pydoc renderer class (#718) + ## [integrations/cohere-v1.0.0] - 2024-05-03 +### 🌀 Miscellaneous + +- Follow up: update Cohere integration to use Cohere SDK v5 (#711) + ## [integrations/cohere-v0.7.0] - 2024-05-02 +### 🌀 Miscellaneous + +- Chore: add license classifiers (#680) +- Update Cohere integration to use Cohere SDK v5 (#702) + ## [integrations/cohere-v0.6.0] - 2024-04-08 ### 🚀 Features @@ -46,21 +83,17 @@ ## [integrations/cohere-v0.5.0] - 2024-03-29 +### 🌀 Miscellaneous + +- Add the Cohere client name to cohere requests (#362) + ## [integrations/cohere-v0.4.1] - 2024-03-21 ### 🐛 Bug Fixes - Fix order of API docs (#447) - -This PR will also push the docs to Readme - Fix tests (#561) -* fix unit tests - -* try - -* remove flaky check - ### 📚 Documentation - Update category slug (#442) @@ -68,14 +101,20 @@ This PR will also push the docs to Readme - Small consistency improvements (#536) - Disable-class-def (#556) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Update Cohere integration to use new generic callable (de)serializers for their callback handlers (#453) - Use `serialize_callable` instead of `serialize_callback_handler` in Cohere (#460) -### Cohere +### 🌀 Miscellaneous +- Choere - remove matching error message from tests (#419) - Fix linting (#509) +- Make tests show coverage (#566) +- Refactor tests (#574) +- Test: relax test constraints (#591) +- Remove references to Python 3.7 (#601) +- Fix: Pin cohere version (#609) ## [integrations/cohere-v0.4.0] - 2024-02-12 @@ -92,32 +131,57 @@ This PR will also push the docs to Readme - Fix failing `TestCohereChatGenerator.test_from_dict_fail_wo_env_var` test (#393) -## [integrations/cohere-v0.3.0] - 2024-01-25 +### 🌀 Miscellaneous -### 🐛 Bug Fixes +- Cohere: generate api docs (#321) +- Fix: update to latest haystack-ai version (#348) -- Fix project urls (#96) +## 
[integrations/cohere-v0.3.0] - 2024-01-25 +### 🐛 Bug Fixes +- Fix project URLs (#96) - Cohere namespace reorg (#271) ### 🚜 Refactor - Use `hatch_vcs` to manage integrations versioning (#103) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - [**breaking**] Rename `model_name` to `model` in the Cohere integration (#222) - Cohere namespace change (#247) +### 🌀 Miscellaneous + +- Cohere: remove unused constant (#91) +- Change default 'input_type' for CohereTextEmbedder (#99) +- Change metadata to meta (#152) +- Add cohere chat generator (#88) +- Optimize API key reading (#162) +- Cohere - change metadata to meta (#178) + ## [integrations/cohere-v0.2.0] - 2023-12-11 ### 🚀 Features - Add support for V3 Embed models to CohereEmbedders (#89) +### 🌀 Miscellaneous + +- Cohere: increase version to prepare release (#92) + ## [integrations/cohere-v0.1.1] - 2023-12-07 +### 🌀 Miscellaneous + +- [cohere] Add text and document embedders (#80) +- [cohere] fix cohere pypi version badge and add Embedder note (#86) + ## [integrations/cohere-v0.0.1] - 2023-12-04 +### 🌀 Miscellaneous + +- Add `cohere_haystack` integration package (#75) + From 92830654d2db3350a48a7eabc337c8bc8075612b Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 9 Dec 2024 19:51:32 +0100 Subject: [PATCH 127/229] chroma: unpin tokenizers (#1233) --- integrations/chroma/pyproject.toml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index c91cc6cb0..40bc9a2b3 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -25,9 +25,8 @@ classifiers = [ dependencies = [ "haystack-ai", "chromadb>=0.5.17", - "typing_extensions>=4.8.0", - "tokenizers>=0.13.2,<=0.20.3" # TODO: remove when Chroma pins tokenizers internally -] + "typing_extensions>=4.8.0" + ] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma#readme" From d3677be7171425aab8d6523605bfb3715f0d640b Mon Sep 17 00:00:00 2001 From: isfuku <54598113+isfuku@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:47:44 -0300 Subject: [PATCH 128/229] feat: warn if LangfuseTracer initialized without tracing enabled (#1231) * feat: warn if LangfuseTracer initialized without tracing enabled * test: warn when lagnfuse tracer init with tracing disabled --- .../tracing/langfuse/tracer.py | 11 +++++++++-- integrations/langfuse/tests/test_tracer.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index d6f2535c7..6af05633e 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -7,7 +7,8 @@ from haystack import logging from haystack.components.generators.openai_utils import _convert_message_to_openai_format from haystack.dataclasses import ChatMessage -from haystack.tracing import Span, Tracer, tracer +from haystack.tracing import Span, Tracer +from haystack.tracing import tracer as proxy_tracer from haystack.tracing import utils as tracing_utils import langfuse @@ -78,7 +79,7 @@ def set_content_tag(self, key: str, value: Any) -> None: :param key: The content tag key. :param value: The content tag value. 
""" - if not tracer.is_content_tracing_enabled: + if not proxy_tracer.is_content_tracing_enabled: return if key.endswith(".input"): if "messages" in value: @@ -126,6 +127,12 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: be publicly accessible to anyone with the tracing URL. If set to `False`, the tracing data will be private and only accessible to the Langfuse account owner. """ + if not proxy_tracer.is_content_tracing_enabled: + logger.warning( + "Traces will not be logged to Langfuse because Haystack tracing is disabled. " + "To enable, set the HAYSTACK_CONTENT_TRACING_ENABLED environment variable to true " + "before importing Haystack." + ) self._tracer = tracer self._context: List[LangfuseSpan] = [] self._name = name diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index 42ae1d07d..d9790ea36 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -1,4 +1,6 @@ import datetime +import logging +import sys from unittest.mock import MagicMock, Mock, patch from haystack.dataclasses import ChatMessage @@ -149,3 +151,17 @@ def test_context_is_empty_after_tracing(self): pass assert tracer._context == [] + + def test_init_with_tracing_disabled(self, monkeypatch, caplog): + # Clear haystack modules because ProxyTracer is initialized whenever haystack is imported + modules_to_clear = [name for name in sys.modules if name.startswith('haystack')] + for name in modules_to_clear: + sys.modules.pop(name, None) + + # Re-import LangfuseTracer and instantiate it with tracing disabled + with caplog.at_level(logging.WARNING): + monkeypatch.setenv("HAYSTACK_CONTENT_TRACING_ENABLED", "false") + from haystack_integrations.tracing.langfuse import LangfuseTracer + + LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) + assert "tracing is disabled" in caplog.text From 54a0573e7c0c047b37a70877cce8ff49f09924e8 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 10 Dec 2024 14:55:18 +0100 Subject: [PATCH 129/229] chore: use text instead of content for ChatMessage in Llama.cpp, Langfuse and Mistral (#1238) --- integrations/langfuse/tests/test_tracing.py | 2 +- .../llama_cpp/chat/chat_generator.py | 2 +- .../llama_cpp/tests/test_chat_generator.py | 30 +++++++++---------- .../tests/test_mistral_chat_generator.py | 4 +-- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index e5737b861..75c1b7a13 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -49,7 +49,7 @@ def test_tracing_integration(llm_class, env_var, expected_trace): "tracer": {"invocation_context": {"user_id": "user_42"}}, } ) - assert "Berlin" in response["llm"]["replies"][0].content + assert "Berlin" in response["llm"]["replies"][0].text assert response["tracer"]["trace_url"] trace_url = response["tracer"]["trace_url"] diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py index d43700215..014dd7169 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py @@ -17,7 +17,7 @@ def 
_convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]: - `content` - `name` (optional) """ - formatted_msg = {"role": message.role.value, "content": message.content} + formatted_msg = {"role": message.role.value, "content": message.text} if message.name: formatted_msg["name"] = message.name diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py index 802fe9128..0ddd78c4f 100644 --- a/integrations/llama_cpp/tests/test_chat_generator.py +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -163,7 +163,7 @@ def test_run_with_valid_message(self, generator_mock): assert isinstance(result["replies"], list) assert len(result["replies"]) == 1 assert isinstance(result["replies"][0], ChatMessage) - assert result["replies"][0].content == "Generated text" + assert result["replies"][0].text == "Generated text" assert result["replies"][0].role == ChatRole.ASSISTANT def test_run_with_generation_kwargs(self, generator_mock): @@ -183,7 +183,7 @@ def test_run_with_generation_kwargs(self, generator_mock): mock_model.create_chat_completion.return_value = mock_output generation_kwargs = {"max_tokens": 128} result = generator.run([ChatMessage.from_system("Write a 200 word paragraph.")], generation_kwargs) - assert result["replies"][0].content == "Generated text" + assert result["replies"][0].text == "Generated text" assert result["replies"][0].meta["finish_reason"] == "length" @pytest.mark.integration @@ -206,7 +206,7 @@ def test_run(self, generator): assert "replies" in result assert isinstance(result["replies"], list) assert len(result["replies"]) > 0 - assert any(answer.lower() in reply.content.lower() for reply in result["replies"]) + assert any(answer.lower() in reply.text.lower() for reply in result["replies"]) @pytest.mark.integration def test_run_rag_pipeline(self, generator): @@ -270,7 +270,7 @@ def test_run_rag_pipeline(self, generator): replies = result["llm"]["replies"] assert len(replies) > 0 - assert any("bioluminescent waves" in reply.content for reply in replies) + assert any("bioluminescent waves" in reply.text.lower() for reply in replies) assert all(reply.role == ChatRole.ASSISTANT for reply in replies) @pytest.mark.integration @@ -308,15 +308,15 @@ def test_json_constraining(self, generator): assert len(result["replies"]) > 0 assert all(reply.role == ChatRole.ASSISTANT for reply in result["replies"]) for reply in result["replies"]: - assert json.loads(reply.content) - assert isinstance(json.loads(reply.content), dict) - assert "people" in json.loads(reply.content) - assert isinstance(json.loads(reply.content)["people"], list) - assert all(isinstance(person, dict) for person in json.loads(reply.content)["people"]) - assert all("name" in person for person in json.loads(reply.content)["people"]) - assert all("age" in person for person in json.loads(reply.content)["people"]) - assert all(isinstance(person["name"], str) for person in json.loads(reply.content)["people"]) - assert all(isinstance(person["age"], int) for person in json.loads(reply.content)["people"]) + assert json.loads(reply.text) + assert isinstance(json.loads(reply.text), dict) + assert "people" in json.loads(reply.text) + assert isinstance(json.loads(reply.text)["people"], list) + assert all(isinstance(person, dict) for person in json.loads(reply.text)["people"]) + assert all("name" in person for person in json.loads(reply.text)["people"]) + assert all("age" in person for person in json.loads(reply.text)["people"]) + assert 
all(isinstance(person["name"], str) for person in json.loads(reply.text)["people"]) + assert all(isinstance(person["age"], int) for person in json.loads(reply.text)["people"]) class TestLlamaCppChatGeneratorFunctionary: @@ -431,8 +431,8 @@ def test_function_call_and_execute(self, generator): second_response = generator.run(messages=messages) assert "replies" in second_response assert len(second_response["replies"]) > 0 - assert any("San Francisco" in reply.content for reply in second_response["replies"]) - assert any("72" in reply.content for reply in second_response["replies"]) + assert any("San Francisco" in reply.text for reply in second_response["replies"]) + assert any("72" in reply.text for reply in second_response["replies"]) class TestLlamaCppChatGeneratorChatML: diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py index 3c95f19db..6277b9c36 100644 --- a/integrations/mistral/tests/test_mistral_chat_generator.py +++ b/integrations/mistral/tests/test_mistral_chat_generator.py @@ -214,7 +214,7 @@ def test_live_run(self): results = component.run(chat_messages) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert "mistral-tiny" in message.meta["model"] assert message.meta["finish_reason"] == "stop" @@ -249,7 +249,7 @@ def __call__(self, chunk: StreamingChunk) -> None: assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] - assert "Paris" in message.content + assert "Paris" in message.text assert "mistral-tiny" in message.meta["model"] assert message.meta["finish_reason"] == "stop" From d22deba6ef45839cd382732fbde08aa313ac6fe4 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 10 Dec 2024 13:56:41 +0000 Subject: [PATCH 130/229] Update the changelog --- integrations/llama_cpp/CHANGELOG.md | 59 +++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/integrations/llama_cpp/CHANGELOG.md b/integrations/llama_cpp/CHANGELOG.md index ea4c05e4d..2d4a8c86e 100644 --- a/integrations/llama_cpp/CHANGELOG.md +++ b/integrations/llama_cpp/CHANGELOG.md @@ -1,46 +1,89 @@ # Changelog +## [integrations/llama_cpp-v0.4.2] - 2024-12-10 + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) + +### ⚙️ CI + +- Adopt uv as installer (#1142) + +### 🧹 Chores + +- Update ruff linting scripts and settings (#1105) +- Unpin `llama-cpp-python` (#1115) +- Fix linting/isort (#1215) +- Use text instead of content for ChatMessage in Llama.cpp, Langfuse and Mistral (#1238) + +### 🌀 Miscellaneous + +- Chore: lamma_cpp - ruff update, don't ruff tests (#998) +- Fix: pin `llama-cpp-python<0.3.0` (#1111) + ## [integrations/llama_cpp-v0.4.1] - 2024-08-08 ### 🐛 Bug Fixes - Replace DynamicChatPromptBuilder with ChatPromptBuilder (#940) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) - Pin `llama-cpp-python>=0.2.87` (#955) -## [integrations/llama_cpp-v0.4.0] - 2024-05-13 +### 🌀 Miscellaneous -### 🐛 Bug Fixes +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Fix: pin llama-cpp-python to an older version (#943) +- Refactor: introduce `_convert_message_to_llamacpp_format` utility function (#939) -- Fix commit (#436) +## [integrations/llama_cpp-v0.4.0] - 2024-05-13 +### 🐛 Bug Fixes +- Llama.cpp: change wrong links and 
imports (#436) - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) - Small consistency improvements (#536) - Disable-class-def (#556) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - [**breaking**] Rename model_path to model in the Llama.cpp integration (#243) -### Llama.cpp +### 🌀 Miscellaneous - Generate api docs (#353) +- Model_name_or_path > model (#418) +- Llama.cpp - review docstrings (#510) +- Llama.cpp - update examples (#511) +- Make tests show coverage (#566) +- Remove references to Python 3.7 (#601) +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Basic implementation of llama.cpp chat generation (#723) ## [integrations/llama_cpp-v0.2.1] - 2024-01-18 +### 🌀 Miscellaneous + +- Update import paths for beta5 (#233) + ## [integrations/llama_cpp-v0.2.0] - 2024-01-17 +### 🌀 Miscellaneous + +- Mount llama_cpp in haystack_integrations (#217) + ## [integrations/llama_cpp-v0.1.0] - 2024-01-09 ### 🚀 Features From df14a979e05879a0165f1dbe7c2737a063c188d2 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 15:40:27 +0100 Subject: [PATCH 131/229] feat: Update AmazonBedrockChatGenerator to use Converse API (BREAKING CHANGE) (#1219) * Initial commit * Update models tested * Add tool support * Update Amazon Bedrock model names in tests * Support for tool streaming * Format * Minot test updates * Lint * Remove truncate init parameter * Pull try down * Add extract_replies_from_response unit test * Add process_streaming_response unit test * Lint * Small test fix * Use EventStream from botocore * Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py Co-authored-by: Stefano Fiorucci * Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py Co-authored-by: Stefano Fiorucci --------- Co-authored-by: Stefano Fiorucci --- .../amazon_bedrock/chat/adapters.py | 569 -------------- .../amazon_bedrock/chat/chat_generator.py | 294 +++++--- .../tests/test_chat_generator.py | 704 +++++++----------- 3 files changed, 467 insertions(+), 1100 deletions(-) delete mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py deleted file mode 100644 index cbb5ee370..000000000 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ /dev/null @@ -1,569 +0,0 @@ -import json -import logging -import os -from abc import ABC, abstractmethod -from typing import Any, Callable, ClassVar, Dict, List, Optional - -from botocore.eventstream import EventStream -from haystack.components.generators.openai_utils import _convert_message_to_openai_format -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk -from transformers import AutoTokenizer, PreTrainedTokenizer - -from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler - -logger = logging.getLogger(__name__) - - -class BedrockModelChatAdapter(ABC): - """ - Base class for Amazon Bedrock chat model adapters. 
- - Each subclass of this class is designed to address the unique specificities of a particular chat LLM it adapts, - focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted chat LLMs. - """ - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: - """ - Initializes the chat adapter with the truncate parameter and generation kwargs. - """ - self.generation_kwargs = generation_kwargs - self.truncate = truncate - - @abstractmethod - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Amazon Bedrock request. - Subclasses should override this method to package the chat messages into the request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - - def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the responses from the Amazon Bedrock response. - - :param response_body: The response body. - :returns: The extracted responses. - """ - return self._extract_messages_from_response(response_body) - - def get_stream_responses( - self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] - ) -> List[ChatMessage]: - streaming_chunks: List[StreamingChunk] = [] - last_decoded_chunk: Dict[str, Any] = {} - for event in stream: - chunk = event.get("chunk") - if chunk: - last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - streaming_chunk = self._build_streaming_chunk(last_decoded_chunk) - streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk - streaming_chunks.append(streaming_chunk) - responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()] - return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] - - @staticmethod - def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None: - """ - Updates target_dict with values from updates_dict. Merges lists instead of overriding them. - - :param target_dict: The dictionary to update. - :param updates_dict: The dictionary with updates. - :param allowed_params: The list of allowed params to use. - """ - for key, value in updates_dict.items(): - if key not in allowed_params: - logger.warning(f"Parameter '{key}' is not allowed and will be ignored.") - continue - if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): - # Merge lists and remove duplicates - target_dict[key] = sorted(set(target_dict[key] + value)) - else: - # Override the value in target_dict - target_dict[key] = value - - def _get_params( - self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str] - ) -> Dict[str, Any]: - """ - Merges params from inference_kwargs with the default params and self.generation_kwargs. - Uses a helper function to merge lists or override values as necessary. - - :param inference_kwargs: The inference kwargs to merge. - :param default_params: The default params to start with. - :param allowed_params: The list of allowed params to use. - :returns: The merged params. 
- """ - # Start with a copy of default_params - kwargs = default_params.copy() - - # Update the default params with self.generation_kwargs and finally inference_kwargs - self._update_params(kwargs, self.generation_kwargs, allowed_params) - self._update_params(kwargs, inference_kwargs, allowed_params) - - return kwargs - - def _ensure_token_limit(self, prompt: str) -> str: - """ - Ensures that the prompt is within the token limit for the model. - :param prompt: The prompt to check. - :returns: The resized prompt. - """ - resize_info = self.check_prompt(prompt) - if resize_info["prompt_length"] != resize_info["new_prompt_length"]: - logger.warning( - "The prompt was truncated from %s tokens to %s tokens so that the prompt length and " - "the answer length (%s tokens) fit within the model's max token limit (%s tokens). " - "Shorten the prompt or it will be cut off.", - resize_info["prompt_length"], - max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore - resize_info["max_length"], - resize_info["model_max_length"], - ) - return str(resize_info["resized_prompt"]) - - @abstractmethod - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - - @abstractmethod - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :returns: The extracted ChatMessage list. - """ - - @abstractmethod - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - - -class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Anthropic Claude chat model. - """ - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - ALLOWED_PARAMS: ClassVar[List[str]] = [ - "anthropic_version", - "max_tokens", - "stop_sequences", - "temperature", - "top_p", - "top_k", - "system", - "tools", - "tool_choice", - ] - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): - """ - Initializes the Anthropic Claude chat adapter. - - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. 
- """ - super().__init__(truncate, generation_kwargs) - - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Anthropic Claude has a limit of at least 100000 tokens - # https://docs.anthropic.com/claude/reference/input-and-output-sizes - model_max_length = self.generation_kwargs.pop("model_max_length", 100000) - - # Truncate prompt if prompt tokens > model_max_length-max_length - # (max_length is the length of the generated text) - # TODO use Anthropic tokenizer to get the precise prompt length - # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting - self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Anthropic Claude request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - default_params = { - "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required - } - - # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) - if stop_sequences: - inference_kwargs["stop_sequences"] = stop_sequences - # pop stream kwarg from inference_kwargs as Anthropic does not support it (if provided) - inference_kwargs.pop("stream", None) - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {**self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: - """ - Prepares the chat messages for the Anthropic Claude request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a dictionary. - """ - body: Dict[str, Any] = {} - system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None - body["messages"] = [ - self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT) - ] - if system: - body["system"] = system - # Ensure token limit for each message in the body - if self.truncate: - for message in body["messages"]: - for content in message["content"]: - content["text"] = self._ensure_token_limit(content["text"]) - return body - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. 
- """ - messages: List[ChatMessage] = [] - if response_body.get("type") == "message": - if response_body.get("stop_reason") == "tool_use": # If `tool_use` we only keep the tool_use content - for content in response_body["content"]: - if content.get("type") == "tool_use": - meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} - json_answer = json.dumps(content) - messages.append(ChatMessage.from_assistant(json_answer, meta=meta)) - else: # For other stop_reason, return all text content - for content in response_body["content"]: - if content.get("type") == "text": - meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} - messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) - - return messages - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": - return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) - return StreamingChunk(content="", meta=chunk) - - def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: - """ - Convert a ChatMessage to a dictionary with the content and role fields. - :param m: The ChatMessage to convert. - :return: The dictionary with the content and role fields. - """ - return {"content": [{"type": "text", "text": m.content}], "role": m.role.value} - - -class MistralChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Mistral chat model. - """ - - chat_template = """ - {% if messages[0]['role'] == 'system' %} - {% set loop_messages = messages[1:] %} - {% set system_message = messages[0]['content'] %} - {% else %} - {% set loop_messages = messages %} - {% set system_message = false %} - {% endif %} - {{bos_token}} - {% for message in loop_messages %} - {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if loop.index0 == 0 and system_message != false %} - {% set content = system_message + '\n' + message['content'] %} - {% else %} - {% set content = message['content'] %} - {% endif %} - {% if message['role'] == 'user' %} - {{ '[INST] ' + content.strip() + ' [/INST]' }} - {% elif message['role'] == 'assistant' %} - {{ content.strip() + eos_token }} - {% endif %} - {% endfor %} - """ - chat_template = "".join(line.strip() for line in chat_template.splitlines()) - - # the above template was designed to match https://docs.mistral.ai/models/#chat-template - # and to support system messages, otherwise we could use the default mistral chat template - # available on HF infrastructure - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - ALLOWED_PARAMS: ClassVar[List[str]] = [ - "max_tokens", - "safe_prompt", - "random_seed", - "temperature", - "top_p", - ] - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): - """ - Initializes the Mistral chat adapter. - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. 
- """ - super().__init__(truncate, generation_kwargs) - - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Mistral has a limit of at least 32000 tokens - model_max_length = self.generation_kwargs.pop("model_max_length", 32000) - - # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer - # a) we should get good estimates for the prompt length - # b) we can use apply_chat_template with the template above to delineate ChatMessages - # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model. - tokenizer: PreTrainedTokenizer - if os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN"): - tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") - else: - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf") - logger.warning( - "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for " - "estimating the prompt length because no HF_TOKEN was found. Using " - "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var " - "HF_TOKEN containing a Hugging Face token and make sure you have access to the model." - ) - - self.prompt_handler = DefaultPromptHandler( - tokenizer=tokenizer, - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Mistral request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - default_params = { - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required - } - # replace stop_words from inference_kwargs with stop, as this is Mistral specific parameter - stop_words = inference_kwargs.pop("stop_words", []) - if stop_words: - inference_kwargs["stop"] = stop_words - - # pop stream kwarg from inference_kwargs as Mistral does not support it (if provided) - inference_kwargs.pop("stream", None) - - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - """ - Prepares the chat messages for the Mistral request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string. - """ - # it would be great to use the default mistral chat template, but it doesn't support system messages - # the class variable defined chat_template is a workaround to support system messages - # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json - # but we'll use our custom chat template - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=[_convert_message_to_openai_format(m) for m in messages], - tokenize=False, - chat_template=self.chat_template, - ) - if self.truncate: - prepared_prompt = self._ensure_token_limit(prepared_prompt) - return prepared_prompt - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. 
- :returns: A dictionary containing the resized prompt and additional information. - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - messages: List[ChatMessage] = [] - responses = response_body.get("outputs", []) - for response in responses: - meta = {k: v for k, v in response.items() if k not in ["text"]} - messages.append(ChatMessage.from_assistant(response["text"], meta=meta)) - return messages - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - response_chunk = chunk.get("outputs", []) - if response_chunk: - return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk) - return StreamingChunk(content="", meta=chunk) - - -class MetaLlama2ChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Meta Llama 2 models. - """ - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html - ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"] - - chat_template = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'] %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = false %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 and system_message != false %}" - "{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ ' ' + content.strip() + ' ' + eos_token }}" - "{% endif %}" - "{% endfor %}" - ) - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: - """ - Initializes the Meta Llama 2 chat adapter. - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. 
- """ - super().__init__(truncate, generation_kwargs) - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Llama 2 has context window size of 4096 tokens - # with some exceptions when the context window has been extended - model_max_length = self.generation_kwargs.pop("model_max_length", 4096) - - # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2 - # a) we should get good estimates for the prompt length (empirically close to llama 2) - # b) we can use apply_chat_template with the template above to delineate ChatMessages - tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base") - tokenizer.bos_token = "" - tokenizer.eos_token = "" - tokenizer.unk_token = "" - self.prompt_handler = DefaultPromptHandler( - tokenizer=tokenizer, - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_gen_len") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Meta Llama 2 request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - """ - default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} - - # no support for stop words in Meta Llama 2 - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - """ - Prepares the chat messages for the Meta Llama 2 request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string ready for the model. - """ - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=messages, tokenize=False, chat_template=self.chat_template - ) - - if self.truncate: - prepared_prompt = self._ensure_token_limit(prepared_prompt) - return prepared_prompt - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - message_tag = "generation" - metadata = {k: v for (k, v) in response_body.items() if k != message_tag} - return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. 
- """ - return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 183198bce..499fe1c24 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -1,12 +1,12 @@ import json import logging -import re -from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional from botocore.config import Config +from botocore.eventstream import EventStream from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -16,18 +16,16 @@ ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session -from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter, MistralChatAdapter - logger = logging.getLogger(__name__) @component class AmazonBedrockChatGenerator: """ - Completes chats using LLMs hosted on Amazon Bedrock. + Completes chats using LLMs hosted on Amazon Bedrock available via the Bedrock Converse API. For example, to use the Anthropic Claude 3 Sonnet model, initialize this component with the - 'anthropic.claude-3-sonnet-20240229-v1:0' model name. + 'anthropic.claude-3-5-sonnet-20240620-v1:0' model name. ### Usage example @@ -40,7 +38,7 @@ class AmazonBedrockChatGenerator: ChatMessage.from_user("What's Natural Language Processing?")] - client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", + client = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", streaming_callback=print_streaming_chunk) client.run(messages, generation_kwargs={"max_tokens": 512}) @@ -58,12 +56,6 @@ class AmazonBedrockChatGenerator: supports Amazon Bedrock. """ - SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, - r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter, - r"([a-z]{2}\.)?mistral.*": MistralChatAdapter, - } - def __init__( self, model: str, @@ -77,7 +69,6 @@ def __init__( generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - truncate: Optional[bool] = True, boto3_config: Optional[Dict[str, Any]] = None, ): """ @@ -111,7 +102,6 @@ def __init__( function that handles the streaming chunks. The callback function receives a [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches the streaming mode on. - :param truncate: Whether to truncate the prompt messages or not. :param boto3_config: The configuration for the boto3 client. :raises ValueError: If the model name is empty or None. 
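# Usage sketch (illustrative): how the Converse-based generator is driven after this
# change. The model ID and parameter values are examples only; the inference keys
# (maxTokens, temperature, topP, stopSequences) and the toolConfig passthrough mirror
# the run() implementation further down in this diff.
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    generation_kwargs={"temperature": 0.1},  # defaults merged with per-run kwargs
)
messages = [
    ChatMessage.from_system("You are a helpful assistant, be super brief in your responses."),
    ChatMessage.from_user("What's the capital of France?"),
]
# Known Converse inference parameters are split out into inferenceConfig;
# any remaining keys end up in additionalModelRequestFields.
result = generator.run(messages, generation_kwargs={"maxTokens": 256, "topP": 0.9})
print(result["replies"][0].content)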
@@ -129,17 +119,8 @@ def __init__( self.aws_profile_name = aws_profile_name self.stop_words = stop_words or [] self.streaming_callback = streaming_callback - self.truncate = truncate self.boto3_config = boto3_config - # get the model adapter for the given model - model_adapter_cls = self.get_model_adapter(model=model) - if not model_adapter_cls: - msg = f"AmazonBedrockGenerator doesn't support the model {model}." - raise AmazonBedrockConfigurationError(msg) - self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {}) - - # create the AWS session and client def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -162,89 +143,9 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception - @component.output_types(replies=List[ChatMessage]) - def run( - self, - messages: List[ChatMessage], - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - ): - """ - Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. - - :param messages: The messages to generate a response to. - :param streaming_callback: - A callback function that is called when a new token is received from the stream. - :param generation_kwargs: Additional generation keyword arguments passed to the model. - :returns: A dictionary with the following keys: - - `replies`: The generated List of `ChatMessage` objects. - """ - generation_kwargs = generation_kwargs or {} - generation_kwargs = generation_kwargs.copy() - - streaming_callback = streaming_callback or self.streaming_callback - generation_kwargs["stream"] = streaming_callback is not None - - # check if the prompt is a list of ChatMessage objects - if not ( - isinstance(messages, list) - and len(messages) > 0 - and all(isinstance(message, ChatMessage) for message in messages) - ): - msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." 
- raise ValueError(msg) - - body = self.model_adapter.prepare_body( - messages=messages, **{"stop_words": self.stop_words, **generation_kwargs} - ) - try: - if streaming_callback: - response = self.client.invoke_model_with_response_stream( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" - ) - response_stream = response["body"] - replies = self.model_adapter.get_stream_responses( - stream=response_stream, streaming_callback=streaming_callback - ) - else: - response = self.client.invoke_model( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" - ) - response_body = json.loads(response.get("body").read().decode("utf-8")) - replies = self.model_adapter.get_responses(response_body=response_body) - except ClientError as exception: - msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" - raise AmazonBedrockInferenceError(msg) from exception - - # rename the meta key to be inline with OpenAI meta output keys - for response in replies: - if response.meta: - if "usage" in response.meta: - if "input_tokens" in response.meta["usage"]: - response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens") - if "output_tokens" in response.meta["usage"]: - response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens") - else: - response.meta["usage"] = {} - if "prompt_token_count" in response.meta: - response.meta["usage"]["prompt_tokens"] = response.meta.pop("prompt_token_count") - if "generation_token_count" in response.meta: - response.meta["usage"]["completion_tokens"] = response.meta.pop("generation_token_count") - - return {"replies": replies} - - @classmethod - def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]: - """ - Returns the model adapter for the given model. - - :param model: The model to get the adapter for. - :returns: The model adapter for the given model, or None if the model is not supported. - """ - for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): - if re.fullmatch(pattern, model): - return adapter - return None + self.generation_kwargs = generation_kwargs or {} + self.stop_words = stop_words or [] + self.streaming_callback = streaming_callback def to_dict(self) -> Dict[str, Any]: """ @@ -263,9 +164,8 @@ def to_dict(self) -> Dict[str, Any]: aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, stop_words=self.stop_words, - generation_kwargs=self.model_adapter.generation_kwargs, + generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, - truncate=self.truncate, boto3_config=self.boto3_config, ) @@ -274,10 +174,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": """ Deserializes the component from a dictionary. - :param data: - Dictionary to deserialize from. + :param data: Dictionary with serialized data. :returns: - Deserialized component. + Instance of `AmazonBedrockChatGenerator`. 
""" init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") @@ -288,3 +187,172 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) return default_from_dict(cls, data) + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + generation_kwargs = generation_kwargs or {} + + # Merge generation_kwargs with defaults + merged_kwargs = self.generation_kwargs.copy() + merged_kwargs.update(generation_kwargs) + + # Extract known inference parameters + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html + inference_config = { + key: merged_kwargs.pop(key, None) + for key in ["maxTokens", "stopSequences", "temperature", "topP"] + if key in merged_kwargs + } + + # Extract tool configuration if present + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = merged_kwargs.pop("toolConfig", None) + + # Any remaining kwargs go to additionalModelRequestFields + additional_fields = merged_kwargs if merged_kwargs else None + + # Prepare system prompts and messages + system_prompts = [] + if messages and messages[0].is_from(ChatRole.SYSTEM): + system_prompts = [{"text": messages[0].text}] + messages = messages[1:] + + messages_list = [{"role": msg.role.value, "content": [{"text": msg.text}]} for msg in messages] + + # Build API parameters + params = { + "modelId": self.model, + "messages": messages_list, + "system": system_prompts, + "inferenceConfig": inference_config, + } + if tool_config: + params["toolConfig"] = tool_config + if additional_fields: + params["additionalModelRequestFields"] = additional_fields + + callback = streaming_callback or self.streaming_callback + + try: + if callback: + response = self.client.converse_stream(**params) + response_stream: EventStream = response.get("stream") + if not response_stream: + msg = "No stream found in the response." 
+ raise AmazonBedrockInferenceError(msg) + replies = self.process_streaming_response(response_stream, callback) + else: + response = self.client.converse(**params) + replies = self.extract_replies_from_response(response) + except ClientError as exception: + msg = f"Could not generate inference for Amazon Bedrock model {self.model} due: {exception}" + raise AmazonBedrockInferenceError(msg) from exception + + return {"replies": replies} + + def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + replies = [] + if "output" in response_body and "message" in response_body["output"]: + message = response_body["output"]["message"] + if message["role"] == "assistant": + content_blocks = message["content"] + + # Common meta information + base_meta = { + "model": self.model, + "index": 0, + "finish_reason": response_body.get("stopReason"), + "usage": { + # OpenAI's format for usage for cross ChatGenerator compatibility + "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), + "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), + "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), + }, + } + + # Process each content block separately + for content_block in content_blocks: + if "text" in content_block: + replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy())) + elif "toolUse" in content_block: + replies.append( + ChatMessage.from_assistant( + content=json.dumps(content_block["toolUse"]), meta=base_meta.copy() + ) + ) + return replies + + def process_streaming_response( + self, response_stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: + replies = [] + current_content = "" + current_tool_use = None + base_meta = { + "model": self.model, + "index": 0, + } + + for event in response_stream: + if "contentBlockStart" in event: + # Reset accumulators for new message + current_content = "" + current_tool_use = None + block_start = event["contentBlockStart"] + if "start" in block_start and "toolUse" in block_start["start"]: + tool_start = block_start["start"]["toolUse"] + current_tool_use = { + "toolUseId": tool_start["toolUseId"], + "name": tool_start["name"], + "input": "", # Will accumulate deltas as string + } + + elif "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + delta_text = delta["text"] + current_content += delta_text + streaming_chunk = StreamingChunk(content=delta_text, meta=None) + # it only makes sense to call callback on text deltas + streaming_callback(streaming_chunk) + elif "toolUse" in delta and current_tool_use: + # Accumulate tool use input deltas + current_tool_use["input"] += delta["toolUse"].get("input", "") + elif "contentBlockStop" in event: + if current_tool_use: + # Parse accumulated input if it's a JSON string + try: + input_json = json.loads(current_tool_use["input"]) + current_tool_use["input"] = input_json + except json.JSONDecodeError: + # Keep as string if not valid JSON + pass + + tool_content = json.dumps(current_tool_use) + replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy())) + elif current_content: + replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy())) + + elif "messageStop" in event: + # not 100% correct for multiple messages but no way around it + for reply in replies: + reply.meta["finish_reason"] = event["messageStop"].get("stopReason") + + elif "metadata" in 
event: + metadata = event["metadata"] + # not 100% correct for multiple messages but no way around it + for reply in replies: + if "usage" in metadata: + usage = metadata["usage"] + reply.meta["usage"] = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + + return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8d6a5c3ee..8eb29729c 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,30 +1,36 @@ import json -import logging -import os -from typing import Any, Dict, Optional, Type -from unittest.mock import MagicMock, patch +from typing import Any, Dict, Optional import pytest from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator -from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( - AnthropicClaudeChatAdapter, - BedrockModelChatAdapter, - MetaLlama2ChatAdapter, - MistralChatAdapter, -) KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"] -MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"] -MISTRAL_MODELS = [ - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", +MODELS_TO_TEST = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "cohere.command-r-plus-v1:0", + "mistral.mistral-large-2402-v1:0", +] +MODELS_TO_TEST_WITH_TOOLS = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0", ] +# so far we've discovered these models support streaming and tool use +STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"] + + +@pytest.fixture +def chat_messages(): + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + return messages + @pytest.mark.parametrize( "boto3_config", @@ -35,12 +41,12 @@ }, ], ) -def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]): +def test_to_dict(mock_boto3_session, boto3_config): """ Test that the to_dict method returns the correct dictionary without aws credentials """ generator = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", + model="cohere.command-r-plus-v1:0", generation_kwargs={"temperature": 0.7}, streaming_callback=print_streaming_chunk, boto3_config=boto3_config, @@ -53,11 +59,10 @@ def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]] "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, - "model": "anthropic.claude-v2", + "model": "cohere.command-r-plus-v1:0", "generation_kwargs": {"temperature": 0.7}, "stop_words": [], "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "truncate": True, "boto3_config": boto3_config, }, } @@ -87,16 +92,14 
@@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, - "model": "anthropic.claude-v2", + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "truncate": True, "boto3_config": boto3_config, }, } ) - assert generator.model == "anthropic.claude-v2" - assert generator.model_adapter.generation_kwargs == {"temperature": 0.7} + assert generator.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert generator.streaming_callback == print_streaming_chunk assert generator.boto3_config == boto3_config @@ -107,13 +110,10 @@ def test_default_constructor(mock_boto3_session, set_env_variables): """ layer = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", + model="anthropic.claude-3-5-sonnet-20240620-v1:0", ) - assert layer.model == "anthropic.claude-v2" - assert layer.truncate is True - assert layer.model_adapter.prompt_handler is not None - assert layer.model_adapter.prompt_handler.model_max_length == 100000 + assert layer.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" # assert mocked boto3 client called exactly once mock_boto3_session.assert_called_once() @@ -134,18 +134,10 @@ def test_constructor_with_generation_kwargs(mock_boto3_session): """ generation_kwargs = {"temperature": 0.7} - layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs) - assert "temperature" in layer.model_adapter.generation_kwargs - assert layer.model_adapter.generation_kwargs["temperature"] == 0.7 - assert layer.model_adapter.truncate is True - - -def test_constructor_with_truncate(mock_boto3_session): - """ - Test that truncate param is correctly set in the model constructor - """ - layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", truncate=False) - assert layer.model_adapter.truncate is False + layer = AmazonBedrockChatGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs + ) + assert layer.generation_kwargs == generation_kwargs def test_constructor_with_empty_model(): @@ -156,208 +148,15 @@ def test_constructor_with_empty_model(): AmazonBedrockChatGenerator(model="") -def test_short_prompt_is_not_truncated(mock_boto3_session): - """ - Test that a short prompt is not truncated - """ - # Define a short mock prompt and its tokenized version - mock_prompt_text = "I am a tokenized prompt" - mock_prompt_tokens = mock_prompt_text.split() - - # Mock the tokenizer so it returns our predefined tokens - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize.return_value = mock_prompt_tokens - - # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens - # Since our mock prompt is 5 tokens long, it doesn't exceed the - # total limit (5 prompt tokens + 3 generated tokens < 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): - layer = AmazonBedrockChatGenerator( - "anthropic.claude-v2", - generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - prompt_after_resize = 
layer.model_adapter._ensure_token_limit(mock_prompt_text) - - # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it - assert prompt_after_resize == mock_prompt_text - - -def test_long_prompt_is_truncated(mock_boto3_session): - """ - Test that a long prompt is truncated - """ - # Define a long mock prompt and its tokenized version - long_prompt_text = "I am a tokenized prompt of length eight" - long_prompt_tokens = long_prompt_text.split() - - # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit - truncated_prompt_text = "I am a tokenized prompt of length" - - # Mock the tokenizer to return our predefined tokens - # convert tokens to our predefined truncated text - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize.return_value = long_prompt_tokens - mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text - - # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens - # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): - layer = AmazonBedrockChatGenerator( - "anthropic.claude-v2", - generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - prompt_after_resize = layer.model_adapter._ensure_token_limit(long_prompt_text) +class TestAmazonBedrockChatGeneratorInference: - # The prompt exceeds the limit, _ensure_token_limit truncates it - assert prompt_after_resize == truncated_prompt_text - - -def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): - """ - Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False - """ - messages = [ChatMessage.from_user("What is the biggest city in United States?")] - - # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()): - generator = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", - truncate=False, - generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - - # Mock the _ensure_token_limit method to track if it is called - with patch.object( - generator.model_adapter, "_ensure_token_limit", wraps=generator.model_adapter._ensure_token_limit - ) as mock_ensure_token_limit: - # Mock the model adapter to avoid actual invocation - generator.model_adapter.prepare_body = MagicMock(return_value={}) - generator.client = MagicMock() - generator.client.invoke_model = MagicMock( - return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} - ) - - generator.model_adapter.get_responses = MagicMock( - return_value=[ - ChatMessage.from_assistant( - content="Some text", - meta={ - "model": "claude-3-sonnet-20240229", - "index": 0, - "finish_reason": "end_turn", - "usage": {"prompt_tokens": 16, "completion_tokens": 55}, - }, - ) - ] - ) - # Invoke the generator - generator.run(messages=messages) - - # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called() - - # Check the prompt passed to prepare_body - 
generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False) - - -@pytest.mark.parametrize( - "model, expected_model_adapter", - [ - ("anthropic.claude-v1", AnthropicClaudeChatAdapter), - ("anthropic.claude-v2", AnthropicClaudeChatAdapter), - ("eu.anthropic.claude-v1", AnthropicClaudeChatAdapter), # cross-region inference - ("us.anthropic.claude-v2", AnthropicClaudeChatAdapter), # cross-region inference - ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter), - ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial - ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial - ("us.meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference - ("eu.meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference - ("de.meta.llama2-130b-v5", MetaLlama2ChatAdapter), # cross-region inference - ("unknown_model", None), - ], -) -def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelChatAdapter]]): - """ - Test that the correct model adapter is returned for a given model - """ - model_adapter = AmazonBedrockChatGenerator.get_model_adapter(model=model) - assert model_adapter == expected_model_adapter - - -class TestAnthropicClaudeAdapter: - def test_prepare_body_with_default_params(self) -> None: - layer = AnthropicClaudeChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": 512, - "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], - } - - body = layer.prepare_body([ChatMessage.from_user(prompt)]) - - assert body == expected_body - - def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeChatAdapter( - truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4} - ) - prompt = "Hello, how are you?" - expected_body = { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": 512, - "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], - "stop_sequences": ["CUSTOM_STOP"], - "temperature": 0.7, - "top_k": 5, - "top_p": 0.8, - } - - body = layer.prepare_body( - [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"] - ) - - assert body == expected_body - - @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration - def test_tools_use(self, model_name): - """ - Test function calling with AWS Bedrock Anthropic adapter - """ - # See https://docs.anthropic.com/en/docs/tool-use for more information - tools = [ - { - "name": "top_song", - "description": "Get the most popular song played on a radio station.", - "input_schema": { - "type": "object", - "properties": { - "sign": { - "type": "string", - "description": "The call sign for the radio station for which you want the most popular" - " song. 
Example calls signs are WZPZ and WKRP.", - } - }, - "required": ["sign"], - }, - } - ] - messages = [] - messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + def test_default_inference_params(self, model_name, chat_messages): client = AmazonBedrockChatGenerator(model=model_name) - response = client.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": {"type": "any"}}) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" @@ -366,102 +165,32 @@ def test_tools_use(self, model_name): assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "top_song" in first_reply.content.lower(), "First reply does not contain top_song" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" - fc_response = json.loads(first_reply.content) - assert "name" in fc_response, "First reply does not contain name of the tool" - assert "input" in fc_response, "First reply does not contain input of the tool" - - -class TestMistralAdapter: - def test_prepare_body_with_default_params(self) -> None: - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = { - "max_tokens": 512, - "prompt": "[INST] Hello, how are you? [/INST]", - } - body = layer.prepare_body([ChatMessage.from_user(prompt)]) + if first_reply.meta and "usage" in first_reply.meta: + assert "prompt_tokens" in first_reply.meta["usage"] + assert "completion_tokens" in first_reply.meta["usage"] - assert body == expected_body + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + @pytest.mark.integration + def test_default_inference_with_streaming(self, model_name, chat_messages): + streaming_callback_called = False + paris_found_in_response = False - def test_prepare_body_with_custom_inference_params(self) -> None: - layer = MistralChatAdapter(truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) - prompt = "Hello, how are you?" - expected_body = { - "prompt": "[INST] Hello, how are you? 
[/INST]", - "max_tokens": 512, - "temperature": 0.7, - "top_p": 0.8, - } + def streaming_callback(chunk: StreamingChunk): + nonlocal streaming_callback_called, paris_found_in_response + streaming_callback_called = True + assert isinstance(chunk, StreamingChunk) + assert chunk.content is not None + if not paris_found_in_response: + paris_found_in_response = "paris" in chunk.content.lower() - body = layer.prepare_body([ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69) - - assert body == expected_body - - def test_mistral_chat_template_correct_order(self): - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_assistant("B"), ChatMessage.from_user("C")]) - layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_user("B"), ChatMessage.from_assistant("C")]) - - def test_mistral_chat_template_incorrect_order(self): - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - try: - layer.prepare_body([ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - try: - layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_user("B")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - try: - layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_system("B")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None: - monkeypatch.delenv("HF_TOKEN", raising=False) - with ( - patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, - patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), - caplog.at_level(logging.WARNING), - ): - MistralChatAdapter(truncate=True, generation_kwargs={}) - mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf") - assert "no HF_TOKEN was found" in caplog.text - - def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None: - monkeypatch.setenv("HF_TOKEN", "test") - with ( - patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, - patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), - ): - MistralChatAdapter(truncate=True, generation_kwargs={}) - mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1") - - @pytest.mark.skipif( - not os.environ.get("HF_API_TOKEN", None), - reason=( - "To run this test, you need to set the HF_API_TOKEN environment variable. 
The associated account must also " - "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`" - ), - ) - @pytest.mark.parametrize("model_name", MISTRAL_MODELS) - @pytest.mark.integration - def test_default_inference_params(self, model_name, chat_messages): - client = AmazonBedrockChatGenerator(model=model_name) + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback) response = client.run(chat_messages) - assert "replies" in response, "Response does not contain 'replies' key" + assert streaming_callback_called, "Streaming callback was not called" + assert paris_found_in_response, "The streaming callback response did not contain 'paris'" replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" @@ -473,77 +202,44 @@ def test_default_inference_params(self, model_name, chat_messages): assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" - -@pytest.fixture -def chat_messages(): - messages = [ - ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), - ChatMessage.from_user("What's the capital of France?"), - ] - return messages - - -class TestMetaLlama2ChatAdapter: - @pytest.mark.integration - def test_prepare_body_with_default_params(self) -> None: - # leave this test as integration because we really need only tokenizer from HF - # that way we can ensure prompt chat message formatting - layer = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512} - - body = layer.prepare_body([ChatMessage.from_user(prompt)]) - - assert body == expected_body - + @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) @pytest.mark.integration - def test_prepare_body_with_custom_inference_params(self) -> None: - # leave this test as integration because we really need only tokenizer from HF - # that way we can ensure prompt chat message formatting - layer = MetaLlama2ChatAdapter( - truncate=True, - generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]}, - ) - prompt = "Hello, how are you?" - - # expected body is different because stop_sequences and top_k are not supported by MetaLlama2 - expected_body = { - "prompt": "[INST] Hello, how are you? [/INST]", - "max_gen_len": 69, - "temperature": 0.7, - "top_p": 0.8, + def test_tools_use(self, model_name): + """ + Test function calling with AWS Bedrock Anthropic adapter + """ + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = { + "tools": [ + { + "toolSpec": { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station " + "for which you want the most popular song. 
" + "Example calls signs are WZPZ and WKRP.", + } + }, + "required": ["sign"], + } + }, + } + } + ], + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + "toolChoice": {"auto": {}}, } - body = layer.prepare_body( - [ChatMessage.from_user(prompt)], - temperature=0.7, - top_p=0.8, - top_k=5, - max_gen_len=69, - stop_sequences=["CUSTOM_STOP"], - ) - - assert body == expected_body - - @pytest.mark.integration - def test_get_responses(self) -> None: - adapter = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) - response_body = {"generation": "This is a single response."} - expected_response = "This is a single response." - response_message = adapter.get_responses(response_body) - # assert that the type of each item in the list is a ChatMessage - for message in response_message: - assert isinstance(message, ChatMessage) - - assert response_message == [ChatMessage.from_assistant(expected_response)] - - @pytest.mark.parametrize("model_name", MODELS_TO_TEST) - @pytest.mark.integration - def test_default_inference_params(self, model_name, chat_messages): + messages = [] + messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) client = AmazonBedrockChatGenerator(model=model_name) - response = client.run(chat_messages) - - assert "replies" in response, "Response does not contain 'replies' key" + response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" @@ -552,32 +248,70 @@ def test_default_inference_params(self, model_name, chat_messages): assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" - if first_reply.meta and "usage" in first_reply.meta: - assert "prompt_tokens" in first_reply.meta["usage"] - assert "completion_tokens" in first_reply.meta["usage"] - - @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + # Some models return thinking message as first and the second one as the tool call + if len(replies) > 1: + second_reply = replies[1] + assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" + assert second_reply.content, "Second reply has no content" + assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" + tool_call = json.loads(second_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" + else: + # case where the model returns the tool call as the first message + # double check that the tool call is correct + tool_call = json.loads(first_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call 
{tool_call} does not contain 'input' key" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" + + @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) @pytest.mark.integration - def test_default_inference_with_streaming(self, model_name, chat_messages): - streaming_callback_called = False - paris_found_in_response = False - - def streaming_callback(chunk: StreamingChunk): - nonlocal streaming_callback_called, paris_found_in_response - streaming_callback_called = True - assert isinstance(chunk, StreamingChunk) - assert chunk.content is not None - if not paris_found_in_response: - paris_found_in_response = "paris" in chunk.content.lower() - - client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback) - response = client.run(chat_messages) + def test_tools_use_with_streaming(self, model_name): + """ + Test function calling with AWS Bedrock Anthropic adapter + """ + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = { + "tools": [ + { + "toolSpec": { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station " + "for which you want the most popular song. Example " + "calls signs are WZPZ and WKRP.", + } + }, + "required": ["sign"], + } + }, + } + } + ], + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + "toolChoice": {"auto": {}}, + } - assert streaming_callback_called, "Streaming callback was not called" - assert paris_found_in_response, "The streaming callback response did not contain 'paris'" + messages = [] + messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=print_streaming_chunk) + response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" @@ -586,5 +320,139 @@ def streaming_callback(chunk: StreamingChunk): assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" + + # Some models return thinking message as first and the second one as the tool call + if len(replies) > 1: + second_reply = replies[1] + assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" + assert second_reply.content, "Second reply has no content" + assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" + tool_call = json.loads(second_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" + else: 
+ # case where the model returns the tool call as the first message + # double check that the tool call is correct + tool_call = json.loads(first_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" + + def test_extract_replies_from_response(self, mock_boto3_session): + """ + Test that extract_replies_from_response correctly processes both text and tool use responses + """ + generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + + # Test case 1: Simple text response + text_response = { + "output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}}, + "stopReason": "complete", + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + } + + replies = generator.extract_replies_from_response(text_response) + assert len(replies) == 1 + assert replies[0].content == "This is a test response" + assert replies[0].role == ChatRole.ASSISTANT + assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["finish_reason"] == "complete" + assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + + # Test case 2: Tool use response + tool_response = { + "output": { + "message": { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}], + } + }, + "stopReason": "tool_call", + "usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40}, + } + + replies = generator.extract_replies_from_response(tool_response) + assert len(replies) == 1 + tool_content = json.loads(replies[0].content) + assert tool_content["toolUseId"] == "123" + assert tool_content["name"] == "test_tool" + assert tool_content["input"] == {"key": "value"} + assert replies[0].meta["finish_reason"] == "tool_call" + assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40} + + # Test case 3: Mixed content response + mixed_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + {"text": "Let me help you with that. I'll use the search tool to find the answer."}, + {"toolUse": {"toolUseId": "456", "name": "search_tool", "input": {"query": "test"}}}, + ], + } + }, + "stopReason": "complete", + "usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60}, + } + + replies = generator.extract_replies_from_response(mixed_response) + assert len(replies) == 2 + assert replies[0].content == "Let me help you with that. I'll use the search tool to find the answer." 
+ tool_content = json.loads(replies[1].content) + assert tool_content["toolUseId"] == "456" + assert tool_content["name"] == "search_tool" + assert tool_content["input"] == {"query": "test"} + + def test_process_streaming_response(self, mock_boto3_session): + """ + Test that process_streaming_response correctly handles streaming events and accumulates responses + """ + generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + + streaming_chunks = [] + + def test_callback(chunk: StreamingChunk): + streaming_chunks.append(chunk) + + # Simulate a stream of events for both text and tool use + events = [ + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "Let me "}}}, + {"contentBlockDelta": {"delta": {"text": "help you."}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "search_tool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"query":'}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '"test"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "complete"}}, + {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}}, + ] + + replies = generator.process_streaming_response(events, test_callback) + + # Verify streaming chunks were received for text content + assert len(streaming_chunks) == 2 + assert streaming_chunks[0].content == "Let me " + assert streaming_chunks[1].content == "help you." + + # Verify final replies + assert len(replies) == 2 + # Check text reply + assert replies[0].content == "Let me help you." + assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["finish_reason"] == "complete" + assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + + # Check tool use reply + tool_content = json.loads(replies[1].content) + assert tool_content["toolUseId"] == "123" + assert tool_content["name"] == "search_tool" + assert tool_content["input"] == {"query": "test"} From cb051b2f888020311cddc146271f96e676471b4f Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 10 Dec 2024 14:51:10 +0000 Subject: [PATCH 132/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 8e4350423..4aceb57bc 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/amazon_bedrock-v2.0.0] - 2024-12-10 + +### 🚀 Features + +- Update AmazonBedrockChatGenerator to use Converse API (BREAKING CHANGE) (#1219) + + ## [integrations/amazon_bedrock-v1.1.1] - 2024-12-03 ### 🐛 Bug Fixes From 87dd2cdd3d53b7800ebe262380443e8b4563e1ed Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 10 Dec 2024 16:00:43 +0100 Subject: [PATCH 133/229] use instead of for in Ollama (#1239) --- .../components/generators/ollama/chat/chat_generator.py | 2 +- integrations/ollama/tests/test_chat_generator.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index b1be7a2db..f598a6e42 100644 --- 
a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -109,7 +109,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator": return default_from_dict(cls, data) def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: - return {"role": message.role.value, "content": message.content} + return {"role": message.role.value, "content": message.text} def _build_message_from_ollama_response(self, ollama_response: ChatResponse) -> ChatMessage: """ diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 0308f42ec..b3df0fbf1 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -102,7 +102,7 @@ def test_build_message_from_ollama_response(self): observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response) assert observed.role == "assistant" - assert observed.content == "Hello! How are you today?" + assert observed.text == "Hello! How are you today?" @pytest.mark.integration def test_run(self): @@ -121,7 +121,7 @@ def test_run(self): assert isinstance(response, dict) assert isinstance(response["replies"], list) - assert answer in response["replies"][0].content + assert answer in response["replies"][0].text @pytest.mark.integration def test_run_with_chat_history(self): @@ -137,7 +137,7 @@ def test_run_with_chat_history(self): assert isinstance(response, dict) assert isinstance(response["replies"], list) - assert "Manchester" in response["replies"][-1].content or "Glasgow" in response["replies"][-1].content + assert "Manchester" in response["replies"][-1].text or "Glasgow" in response["replies"][-1].text @pytest.mark.integration def test_run_model_unavailable(self): @@ -166,4 +166,4 @@ def test_run_with_streaming(self): assert isinstance(response, dict) assert isinstance(response["replies"], list) - assert "Manchester" in response["replies"][-1].content or "Glasgow" in response["replies"][-1].content + assert "Manchester" in response["replies"][-1].text or "Glasgow" in response["replies"][-1].text From 9bca071d6682f964d915bc4770014a7048ef85b4 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 10 Dec 2024 15:02:08 +0000 Subject: [PATCH 134/229] Update the changelog --- integrations/ollama/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 9e2e0a0cb..c28767257 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/ollama-v2.1.1] - 2024-12-10 + +### 🌀 Miscellaneous + +- Chore: use `text` instead of `content` for `ChatMessage` in Ollama (#1239) + ## [integrations/ollama-v2.1.0] - 2024-11-28 ### 🚀 Features From 574316d8724470ae4185fd8f0bc4440befc0f1cc Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 10 Dec 2024 16:06:02 +0100 Subject: [PATCH 135/229] fix: logger for index deletion failures in Azure AI (#1240) * fix logging for failed index deletion --- integrations/azure_ai_search/tests/conftest.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py index 89369c87e..e741e3066 100644 --- a/integrations/azure_ai_search/tests/conftest.py +++ 
b/integrations/azure_ai_search/tests/conftest.py @@ -11,9 +11,12 @@ from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore +logger = logging.getLogger(__name__) + + # This is the approximate time in seconds it takes for the documents to be available in Azure Search index SLEEP_TIME_IN_SECONDS = 10 -MAX_WAIT_TIME_FOR_INDEX_DELETION = 5 +MAX_WAIT_TIME_FOR_INDEX_DELETION = 10 @pytest.fixture() @@ -75,8 +78,8 @@ def wait_for_index_deletion(client, index_name): try: client.delete_index(index_name) if not wait_for_index_deletion(client, index_name): - logging.error(f"Index {index_name} was not properly deleted.") + logger.error(f"Index {index_name} was not properly deleted.") except ResourceNotFoundError: - logging.info(f"Index {index_name} was already deleted or not found.") + logger.error(f"Index {index_name} was already deleted or not found.") except Exception as e: - logging.error(f"Unexpected error when deleting index {index_name}: {e}") + logger.error(f"Unexpected error when deleting index {index_name}: {e}") From 943f8e50cc9dc30ce6a8fd867bcb802dad2c2e74 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 10 Dec 2024 18:15:37 +0100 Subject: [PATCH 136/229] fix: GoogleAI - fix the content type of `ChatMessage` `content` from function (#1241) * fix Gemini * avoid directly accessing role --- .../generators/google_ai/chat/gemini.py | 49 ++++++++++--------- .../tests/generators/chat/test_chat_gemini.py | 13 ++--- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index ef7d583be..089b38b10 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -1,3 +1,4 @@ +import json import logging from typing import Any, Callable, Dict, List, Optional, Union @@ -36,12 +37,12 @@ class GoogleAIGeminiChatGenerator: messages = [ChatMessage.from_user("What is the most interesting thing you know?")] res = gemini_chat.run(messages=messages) for reply in res["replies"]: - print(reply.content) + print(reply.text) messages += res["replies"] + [ChatMessage.from_user("Tell me more about it")] res = gemini_chat.run(messages=messages) for reply in res["replies"]: - print(reply.content) + print(reply.text) ``` @@ -85,14 +86,14 @@ def get_current_weather(location: str, unit: str = "celsius") -> str: gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", api_key=Secret.from_token(""), tools=[tool]) - messages = [ChatMessage.from_user(content = "What is the temperature in celsius in Berlin?")] + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] res = gemini_chat.run(messages=messages) - weather = get_current_weather(**res["replies"][0].content) + weather = get_current_weather(**json.loads(res["replies"][0].text)) messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] res = gemini_chat.run(messages=messages) for reply in res["replies"]: - print(reply.content) + print(reply.text) ``` """ @@ -230,45 +231,45 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) def _message_to_part(self, message: ChatMessage) -> Part: - if message.role == ChatRole.ASSISTANT and message.name: + if 
message.is_from(ChatRole.ASSISTANT) and message.name: p = Part() p.function_call.name = message.name p.function_call.args = {} - for k, v in message.content.items(): + for k, v in json.loads(message.text).items(): p.function_call.args[k] = v return p - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): p = Part() - p.text = message.content + p.text = message.text return p - elif message.role == ChatRole.FUNCTION: + elif message.is_from(ChatRole.FUNCTION): p = Part() p.function_response.name = message.name - p.function_response.response = message.content + p.function_response.response = message.text return p - elif message.role == ChatRole.USER: - return self._convert_part(message.content) + elif message.is_from(ChatRole.USER): + return self._convert_part(message.text) def _message_to_content(self, message: ChatMessage) -> Content: - if message.role == ChatRole.ASSISTANT and message.name: + if message.is_from(ChatRole.ASSISTANT) and message.name: part = Part() part.function_call.name = message.name part.function_call.args = {} - for k, v in message.content.items(): + for k, v in json.loads(message.text).items(): part.function_call.args[k] = v - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): part = Part() - part.text = message.content - elif message.role == ChatRole.FUNCTION: + part.text = message.text + elif message.is_from(ChatRole.FUNCTION): part = Part() part.function_response.name = message.name - part.function_response.response = message.content - elif message.role == ChatRole.USER: - part = self._convert_part(message.content) + part.function_response.response = message.text + elif message.is_from(ChatRole.USER): + part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model" + role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) @@ -338,7 +339,7 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess elif part.function_call: candidate_metadata["function_call"] = part.function_call new_message = ChatMessage.from_assistant( - content=dict(part.function_call.args.items()), meta=candidate_metadata + content=json.dumps(dict(part.function_call.args)), meta=candidate_metadata ) new_message.name = part.function_call.name replies.append(new_message) @@ -366,7 +367,7 @@ def _get_stream_response( replies.append(ChatMessage.from_assistant(content=content, meta=metadata)) elif "function_call" in part and len(part["function_call"]) > 0: metadata["function_call"] = part["function_call"] - content = part["function_call"]["args"] + content = json.dumps(dict(part["function_call"]["args"])) new_message = ChatMessage.from_assistant(content=content, meta=metadata) new_message.name = part["function_call"]["name"] replies.append(new_message) diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index cb42f0ff8..b8658a4dd 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -1,3 +1,4 @@ +import json import os from unittest.mock import patch @@ 
-223,9 +224,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 # check the first response is a function call chat_message = response["replies"][0] assert "function_call" in chat_message.meta - assert chat_message.content == {"location": "Berlin", "unit": "celsius"} + assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - weather = get_current_weather(**chat_message.content) + weather = get_current_weather(**json.loads(chat_message.text)) messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response @@ -235,7 +236,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 # check the second response is not a function call chat_message = response["replies"][0] assert "function_call" not in chat_message.meta - assert isinstance(chat_message.content, str) + assert isinstance(chat_message.text, str) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") @@ -269,9 +270,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 # check the first response is a function call chat_message = response["replies"][0] assert "function_call" in chat_message.meta - assert chat_message.content == {"location": "Berlin", "unit": "celsius"} + assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - weather = get_current_weather(**response["replies"][0].content) + weather = get_current_weather(**json.loads(response["replies"][0].text)) messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response @@ -281,7 +282,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 # check the second response is not a function call chat_message = response["replies"][0] assert "function_call" not in chat_message.meta - assert isinstance(chat_message.content, str) + assert isinstance(chat_message.text, str) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") From adba1662eaaf12d5a7812e47b990e1c553cb5d98 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 10 Dec 2024 17:16:59 +0000 Subject: [PATCH 137/229] Update the changelog --- integrations/google_ai/CHANGELOG.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 7171b0069..404303412 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## [integrations/google_ai-v4.0.0] - 2024-12-10 + +### 🐛 Bug Fixes + +- GoogleAI - fix the content type of `ChatMessage` `content` from function (#1241) + +### 🧹 Chores + +- Fix linting/isort (#1215) + +### 🌀 Miscellaneous + +- Chore: use class methods to create `ChatMessage` (#1222) + ## [integrations/google_ai-v3.0.2] - 2024-11-19 ### 🐛 Bug Fixes From 0811b3b305aff6af7838e0078df8ee70ec853cac Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 11 Dec 2024 14:54:35 +0100 Subject: [PATCH 138/229] fix vertex (#1242) --- .../generators/google_vertex/chat/gemini.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py 
b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index c94367b41..2309ca718 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -1,3 +1,4 @@ +import json import logging from typing import Any, Callable, Dict, Iterable, List, Optional, Union @@ -41,7 +42,7 @@ class VertexAIGeminiChatGenerator: messages = [ChatMessage.from_user("Tell me the name of a movie")] res = gemini_chat.run(messages) - print(res["replies"][0].content) + print(res["replies"][0].text) >>> The Shawshank Redemption ``` """ @@ -209,31 +210,31 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: def _message_to_part(self, message: ChatMessage) -> Part: if message.role == ChatRole.ASSISTANT and message.name: p = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) - for k, v in message.content.items(): + for k, v in json.loads(message.text).items(): p.function_call.args[k] = v return p - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: - return Part.from_text(message.content) - elif message.role == ChatRole.FUNCTION: - return Part.from_function_response(name=message.name, response=message.content) - elif message.role == ChatRole.USER: - return self._convert_part(message.content) + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): + return Part.from_text(message.text) + elif message.is_from(ChatRole.FUNCTION): + return Part.from_function_response(name=message.name, response=message.text) + elif message.is_from(ChatRole.USER): + return self._convert_part(message.text) def _message_to_content(self, message: ChatMessage) -> Content: - if message.role == ChatRole.ASSISTANT and message.name: + if message.is_from(ChatRole.ASSISTANT) and message.name: part = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) - for k, v in message.content.items(): + for k, v in json.loads(message.text).items(): part.function_call.args[k] = v - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: - part = Part.from_text(message.content) - elif message.role == ChatRole.FUNCTION: - part = Part.from_function_response(name=message.name, response=message.content) - elif message.role == ChatRole.USER: - part = self._convert_part(message.content) + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): + part = Part.from_text(message.text) + elif message.is_from(ChatRole.FUNCTION): + part = Part.from_function_response(name=message.name, response=message.text) + elif message.is_from(ChatRole.USER): + part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model" + role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) @@ -283,7 +284,7 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: elif part.function_call: metadata["function_call"] = part.function_call new_message = ChatMessage.from_assistant( - content=dict(part.function_call.args.items()), meta=metadata + content=json.dumps(dict(part.function_call.args)), meta=metadata ) new_message.name = part.function_call.name replies.append(new_message) @@ -311,7 
+312,7 @@ def _get_stream_response( replies.append(ChatMessage.from_assistant(content, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call - content = dict(part.function_call.args.items()) + content = json.dumps(dict(part.function_call.args)) new_message = ChatMessage.from_assistant(content, meta=metadata) new_message.name = part.function_call.name replies.append(new_message) From 04d21da9f8b404881ef4a6e680af08f4efae7a2e Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 11 Dec 2024 13:56:07 +0000 Subject: [PATCH 139/229] Update the changelog --- integrations/google_vertex/CHANGELOG.md | 57 ++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index ea2a8fb18..71f433509 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,15 +1,30 @@ # Changelog +## [integrations/google_vertex-v4.0.0] - 2024-12-11 + +### 🐛 Bug Fixes + +- Fix: Google Vertex - fix the content type of `ChatMessage` `content` from function (#1242) + +### 🧹 Chores + +- Fix linting/isort (#1215) + +### 🌀 Miscellaneous + +- Chore: use class methods to create `ChatMessage` (#1222) + ## [integrations/google_vertex-v3.0.0] - 2024-11-14 ### 🐛 Bug Fixes - VertexAIGeminiGenerator - remove support for tools and change output type (#1180) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Fix Vertex tests (#1163) + ## [integrations/google_vertex-v2.2.0] - 2024-10-23 ### 🐛 Bug Fixes @@ -17,10 +32,11 @@ - Make "project-id" parameter optional during initialization (#1141) - Make project-id optional in all VertexAI generators (#1147) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Adopt uv as installer (#1142) + ## [integrations/google_vertex-v2.1.0] - 2024-10-04 ### 🚀 Features @@ -40,22 +56,37 @@ - Do not retry tests in `hatch run test` command (#954) - Add tests for VertexAIChatGeminiGenerator and migrate from preview package in vertexai (#1042) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) - Update ruff linting scripts and settings (#1105) +### 🌀 Miscellaneous + +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Ping `protobuf` to `<5.28` to fix Google Vertex Components serialization (#1050) +- Update docstrings to remove vertexai preview package (#1074) +- Chore: Unpin protobuf dependency in Google Vertex integration (#1085) +- Chore: pin `google-cloud-aiplatform>=1.61` and fix tests (#1124) + ## [integrations/google_vertex-v1.1.0] - 2024-03-28 +### 🌀 Miscellaneous + +- Add pyarrow as required dependency (#629) + ## [integrations/google_vertex-v1.0.0] - 2024-03-27 ### 🐛 Bug Fixes - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) @@ -63,20 +94,32 @@ This PR will also push the docs to Readme - Small consistency improvements (#536) - Disable-class-def (#556) -### Google_vertex +### 🌀 Miscellaneous - Create api docs (#355) +- Make tests show coverage (#566) +- Remove references to Python 3.7 (#601) +- Google Generators: change `answers` to `replies` (#626) ## [integrations/google_vertex-v0.2.0] - 2024-01-26 +### 🌀 Miscellaneous + +- Refact!: change import paths (#273) + ## [integrations/google_vertex-v0.1.0] - 2024-01-03 ### 🐛 Bug 
Fixes - The default model of VertexAIImagegenerator (#158) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Replace - with _ (#114) +### 🌀 Miscellaneous + +- Change metadata to meta (#152) +- Add VertexAI prefix to GeminiGenerator and GeminiChatGenerator components (#166) + From 4af651a5c0a7b071cfad7779904e6022f9a770a2 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 11 Dec 2024 15:34:24 +0000 Subject: [PATCH 140/229] Update the changelog --- integrations/langfuse/CHANGELOG.md | 47 ++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 7cf1cc0c4..0ecb42b48 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,23 +1,42 @@ # Changelog +## [unreleased] + +### 🚀 Features + +- Warn if LangfuseTracer initialized without tracing enabled (#1231) + +### 🧹 Chores + +- Use text instead of content for ChatMessage in Llama.cpp, Langfuse and Mistral (#1238) + +### 🌀 Miscellaneous + +- Chore: Fix tracing_context_var lint errors (#1220) + ## [integrations/langfuse-v0.6.0] - 2024-11-18 ### 🚀 Features - Add support for ttft (#1161) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Adopt uv as installer (#1142) +### 🌀 Miscellaneous + +- Fixed TypeError in LangfuseTrace (#1184) + ## [integrations/langfuse-v0.5.0] - 2024-10-01 -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Update ruff linting scripts and settings (#1105) -### Langfuse +### 🌀 Miscellaneous +- Fix: Add delay to flush the Langfuse traces (#1091) - Add invocation_context to identify traces (#1089) ## [integrations/langfuse-v0.4.0] - 2024-09-17 @@ -38,14 +57,27 @@ - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - `Langfuse` - replace DynamicChatPromptBuilder with ChatPromptBuilder (#925) - Remove all `DynamicChatPromptBuilder` references in Langfuse integration (#931) +### 🌀 Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Chore: Update Langfuse README to avoid common initialization issues (#952) +- Chore: langfuse - ruff update, don't ruff tests (#992) + ## [integrations/langfuse-v0.2.0] - 2024-06-18 +### 🌀 Miscellaneous + +- Feat: add support for Azure generators (#815) + ## [integrations/langfuse-v0.1.0] - 2024-06-13 ### 🚀 Features @@ -56,8 +88,13 @@ - Performance optimizations and value error when streaming in langfuse (#798) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Use ChatMessage to_openai_format, update unit tests, pydocs (#725) +### 🌀 Miscellaneous + +- Chore: change the pydoc renderer class (#718) +- Docs: add missing api references (#728) + From 27a91ff750db1b673a8688dc3d225011bfc74fc7 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Wed, 11 Dec 2024 17:40:13 +0100 Subject: [PATCH 141/229] feat: support model_arn in AmazonBedrockGenerator (#1244) * feat: support model_arn in AmazonBedrockGenerator * add test * fix tests * apply feedback * fix lint --- .../generators/amazon_bedrock/generator.py | 53 +++++++++++++++--- .../amazon_bedrock/tests/test_generator.py | 55 ++++++++++++++++++- 2 files changed, 100 insertions(+), 8 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 941fdbf71..79dc07cdc 
100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args from botocore.config import Config from botocore.exceptions import ClientError @@ -75,6 +75,26 @@ class AmazonBedrockGenerator: r"([a-z]{2}\.)?mistral.*": MistralAdapter, } + SUPPORTED_MODEL_FAMILIES: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = { + "amazon.titan-text": AmazonTitanAdapter, + "ai21.j2": AI21LabsJurassic2Adapter, + "cohere.command": CohereCommandAdapter, + "cohere.command-r": CohereCommandRAdapter, + "anthropic.claude": AnthropicClaudeAdapter, + "meta.llama": MetaLlamaAdapter, + "mistral": MistralAdapter, + } + + MODEL_FAMILIES = Literal[ + "amazon.titan-text", + "ai21.j2", + "cohere.command", + "cohere.command-r", + "anthropic.claude", + "meta.llama", + "mistral", + ] + def __init__( self, model: str, @@ -89,6 +109,7 @@ def __init__( truncate: Optional[bool] = True, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, boto3_config: Optional[Dict[str, Any]] = None, + model_family: Optional[MODEL_FAMILIES] = None, **kwargs, ): """ @@ -105,6 +126,8 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. :param boto3_config: The configuration for the boto3 client. + :param model_family: The model family to use. If not provided, the model adapter is selected based on the model + name. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -125,6 +148,7 @@ def __init__( self.streaming_callback = streaming_callback self.boto3_config = boto3_config self.kwargs = kwargs + self.model_family = model_family def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -163,10 +187,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: max_length=self.max_length or 100, ) - model_adapter_cls = self.get_model_adapter(model=model) - if not model_adapter_cls: - msg = f"AmazonBedrockGenerator doesn't support the model {model}." - raise AmazonBedrockConfigurationError(msg) + model_adapter_cls = self.get_model_adapter(model=model, model_family=model_family) self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length) def _ensure_token_limit(self, prompt: str) -> str: @@ -250,17 +271,34 @@ def run( return {"replies": replies} @classmethod - def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]: + def get_model_adapter(cls, model: str, model_family: Optional[str] = None) -> Type[BedrockModelAdapter]: """ Gets the model adapter for the given model. + If `model_family` is provided, the adapter for the model family is returned. + If `model_family` is not provided, the adapter is auto-detected based on the model name. + :param model: The model name. + :param model_family: The model family. :returns: The model adapter class, or None if no adapter is found. 
+ :raises AmazonBedrockConfigurationError: If the model family is not supported or the model cannot be + auto-detected. """ + if model_family: + if model_family not in cls.SUPPORTED_MODEL_FAMILIES: + msg = f"Model family {model_family} is not supported. Must be one of {get_args(cls.MODEL_FAMILIES)}." + raise AmazonBedrockConfigurationError(msg) + return cls.SUPPORTED_MODEL_FAMILIES[model_family] + for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): if re.fullmatch(pattern, model): return adapter - return None + + msg = ( + f"Could not auto-detect model family of {model}. " + f"`model_family` parameter must be one of {get_args(cls.MODEL_FAMILIES)}." + ) + raise AmazonBedrockConfigurationError(msg) def to_dict(self) -> Dict[str, Any]: """ @@ -282,6 +320,7 @@ def to_dict(self) -> Dict[str, Any]: truncate=self.truncate, streaming_callback=callback_name, boto3_config=self.boto3_config, + model_family=self.model_family, **self.kwargs, ) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 54b185da5..3d2cbc01f 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -4,6 +4,9 @@ import pytest from haystack.dataclasses import StreamingChunk +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, +) from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator from haystack_integrations.components.generators.amazon_bedrock.adapters import ( AI21LabsJurassic2Adapter, @@ -48,6 +51,7 @@ def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]] "temperature": 10, "streaming_callback": None, "boto3_config": boto3_config, + "model_family": None, }, } @@ -79,6 +83,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any "model": "anthropic.claude-v2", "max_length": 99, "boto3_config": boto3_config, + "model_family": "anthropic.claude", }, } ) @@ -86,6 +91,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any assert generator.max_length == 99 assert generator.model == "anthropic.claude-v2" assert generator.boto3_config == boto3_config + assert generator.model_family == "anthropic.claude" def test_default_constructor(mock_boto3_session, set_env_variables): @@ -294,7 +300,6 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference ("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial - ("unknown_model", None), ], ) def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): @@ -305,6 +310,54 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed assert model_adapter == expected_model_adapter +@pytest.mark.parametrize( + "model_family, expected_model_adapter", + [ + ("anthropic.claude", AnthropicClaudeAdapter), + ("cohere.command", CohereCommandAdapter), + ("cohere.command-r", CohereCommandRAdapter), + ("ai21.j2", AI21LabsJurassic2Adapter), + ("amazon.titan-text", AmazonTitanAdapter), + ("meta.llama", MetaLlamaAdapter), + ("mistral", MistralAdapter), + ], +) +def test_get_model_adapter_with_model_family( + model_family: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]] +): + """ + Test that the correct model adapter is 
returned for a given model model_family + """ + model_adapter = AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family=model_family) + assert model_adapter == expected_model_adapter + + +def test_get_model_adapter_with_invalid_model_family(): + """ + Test that an error is raised when an invalid model_family is provided + """ + with pytest.raises(AmazonBedrockConfigurationError): + AmazonBedrockGenerator.get_model_adapter(model="arn:123435423", model_family="invalid") + + +def test_get_model_adapter_auto_detect_family_fails(): + """ + Test that an error is raised when auto-detection of model_family fails + """ + with pytest.raises(AmazonBedrockConfigurationError): + AmazonBedrockGenerator.get_model_adapter(model="arn:123435423") + + +def test_get_model_adapter_model_family_over_auto_detection(): + """ + Test that the model_family is used over auto-detection + """ + model_adapter = AmazonBedrockGenerator.get_model_adapter( + model="cohere.command-text-v14", model_family="anthropic.claude" + ) + assert model_adapter == AnthropicClaudeAdapter + + class TestAnthropicClaudeAdapter: def test_default_init(self) -> None: adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100) From 31d14a104530e161125e873f5792079de764f855 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 11 Dec 2024 16:43:11 +0000 Subject: [PATCH 142/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 4aceb57bc..46eeea7b7 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/amazon_bedrock-v2.1.0] - 2024-12-11 + +### 🚀 Features + +- Support model_arn in AmazonBedrockGenerator (#1244) + + ## [integrations/amazon_bedrock-v2.0.0] - 2024-12-10 ### 🚀 Features From 63f20c03e3637ee248083631e978a89d2be48dbe Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 12 Dec 2024 09:35:30 +0100 Subject: [PATCH 143/229] chore: Update docstring and type of fuzziness (#1243) * Update docstring and type of fuzziness * Add test --- .../retrievers/opensearch/bm25_retriever.py | 14 ++++++--- .../opensearch/document_store.py | 12 ++++++-- .../opensearch/tests/test_bm25_retriever.py | 29 +++++++++++++++++++ 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index 4a8478e2c..69288a5cf 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -27,7 +27,7 @@ def __init__( *, document_store: OpenSearchDocumentStore, filters: Optional[Dict[str, Any]] = None, - fuzziness: str = "AUTO", + fuzziness: Union[int, str] = "AUTO", top_k: int = 10, scale_score: bool = False, all_terms_must_match: bool = False, @@ -40,8 +40,14 @@ def __init__( :param document_store: An instance of OpenSearchDocumentStore to use with the Retriever. :param filters: Filters to narrow down the search for documents in the Document Store. - :param fuzziness: Fuzziness parameter for full-text queries to apply approximate string matching. 
- For more information, see [OpenSearch fuzzy query](https://opensearch.org/docs/latest/query-dsl/term/fuzzy/). + :param fuzziness: Determines how approximate string matching is applied in full-text queries. + This parameter sets the number of character edits (insertions, deletions, or substitutions) + required to transform one word into another. For example, the "fuzziness" between the words + "wined" and "wind" is 1 because only one edit is needed to match them. + + Use "AUTO" (the default) for automatic adjustment based on term length, which is optimal for + most scenarios. For detailed guidance, refer to the + [OpenSearch fuzzy query documentation](https://opensearch.org/docs/latest/query-dsl/term/fuzzy/). :param top_k: Maximum number of documents to return. :param scale_score: If `True`, scales the score of retrieved documents to a range between 0 and 1. This is useful when comparing documents across different indexes. @@ -153,7 +159,7 @@ def run( filters: Optional[Dict[str, Any]] = None, all_terms_must_match: Optional[bool] = None, top_k: Optional[int] = None, - fuzziness: Optional[str] = None, + fuzziness: Optional[Union[int, str]] = None, scale_score: Optional[bool] = None, custom_query: Optional[Dict[str, Any]] = None, ): diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 4ec2420b3..6cb5295f0 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -340,7 +340,7 @@ def _bm25_retrieval( query: str, *, filters: Optional[Dict[str, Any]] = None, - fuzziness: str = "AUTO", + fuzziness: Union[int, str] = "AUTO", top_k: int = 10, scale_score: bool = False, all_terms_must_match: bool = False, @@ -357,8 +357,14 @@ def _bm25_retrieval( :param query: String to search in saved Documents' text. :param filters: Optional filters to narrow down the search space. - :param fuzziness: Fuzziness parameter passed to OpenSearch, defaults to "AUTO". see the official documentation - for valid [fuzziness values](https://www.elastic.co/guide/en/OpenSearch/reference/current/common-options.html#fuzziness) + :param fuzziness: Determines how approximate string matching is applied in full-text queries. + This parameter sets the number of character edits (insertions, deletions, or substitutions) + required to transform one word into another. For example, the "fuzziness" between the words + "wined" and "wind" is 1 because only one edit is needed to match them. + + Use "AUTO" (the default) for automatic adjustment based on term length, which is optimal for + most scenarios. For detailed guidance, refer to the + [OpenSearch fuzzy query documentation](https://opensearch.org/docs/latest/query-dsl/term/fuzzy/). 
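To make the fuzziness setting concrete, a small sketch of a BM25 retriever configured with an explicit edit distance (instead of the default "AUTO") might look like this; the host and index are placeholders and a running OpenSearch instance is assumed.

```python
# Illustrative sketch only: host and index are placeholders.
from haystack_integrations.components.retrievers.opensearch import OpenSearchBM25Retriever
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore

document_store = OpenSearchDocumentStore(hosts="http://localhost:9200", index="default")

# Allow at most one character edit per term, so "wind" would still match "wined".
retriever = OpenSearchBM25Retriever(document_store=document_store, fuzziness=1)
result = retriever.run(query="wind energy", top_k=5)
print(result["documents"])
```
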
:param top_k: Maximum number of Documents to return, defaults to 10 :param scale_score: If `True` scales the Document`s scores between 0 and 1, defaults to False :param all_terms_must_match: If `True` all terms in `query` must be present in the Document, defaults to False diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index ef3275608..48fc31419 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -121,6 +121,35 @@ def test_from_dict(_mock_opensearch_client): assert retriever._filter_policy == FilterPolicy.REPLACE +@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") +def test_from_dict_not_defaults(_mock_opensearch_client): + data = { + "type": "haystack_integrations.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever", + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "fuzziness": 0, + "top_k": 15, + "scale_score": True, + "filter_policy": "replace", + "custom_query": {"some": "custom query"}, + "raise_on_failure": True, + }, + } + retriever = OpenSearchBM25Retriever.from_dict(data) + assert retriever._document_store + assert retriever._filters == {} + assert retriever._fuzziness == 0 + assert retriever._top_k == 15 + assert retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._custom_query == {"some": "custom query"} + assert retriever._raise_on_failure is True + + def test_run(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] From e89855781329b3d7f7b29e6ce7b75b461671ac91 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 12 Dec 2024 10:38:00 +0000 Subject: [PATCH 144/229] Update the changelog --- integrations/opensearch/CHANGELOG.md | 53 ++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/integrations/opensearch/CHANGELOG.md b/integrations/opensearch/CHANGELOG.md index afd8a57c2..fef7b4bc3 100644 --- a/integrations/opensearch/CHANGELOG.md +++ b/integrations/opensearch/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/opensearch-v1.2.0] - 2024-12-12 + +### 🧹 Chores + +- Update docstring and type of fuzziness (#1243) + + ## [integrations/opensearch-v1.1.0] - 2024-10-29 ### 🚀 Features @@ -14,16 +21,21 @@ - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI + +- Adopt uv as installer (#1142) + +### 🧹 Chores - OpenSearch - remove legacy filter support (#1067) - Update changelog after removing legacy filters (#1083) - Update ruff linting scripts and settings (#1105) -- Adopt uv as installer (#1142) -### Docs +### 🌀 Miscellaneous +- Docs: Update OpenSearchEmbeddingRetriever docstrings (#947) - Update BM25 docstrings (#945) +- Chore: opensearch - ruff update, don't ruff tests (#988) ## [integrations/opensearch-v0.9.0] - 2024-08-01 @@ -31,6 +43,7 @@ - Support aws authentication with OpenSearchDocumentStore (#920) + ## [integrations/opensearch-v0.8.1] - 2024-07-15 ### 🚀 Features @@ -42,10 +55,14 @@ - `OpenSearch` - Fallback to default filter policy when deserializing retrievers without the init parameter (#895) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Update ruff invocation to include check 
parameter (#853) +### 🌀 Miscellaneous + +- Chore: Minor retriever pydoc fix (#884) + ## [integrations/opensearch-v0.7.1] - 2024-06-27 ### 🐛 Bug Fixes @@ -53,6 +70,7 @@ - Serialization for custom_query in OpenSearch retrievers (#851) - Support legacy filters with OpenSearchDocumentStore (#850) + ## [integrations/opensearch-v0.7.0] - 2024-06-25 ### 🚀 Features @@ -67,8 +85,6 @@ - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) @@ -76,13 +92,21 @@ This PR will also push the docs to Readme - Small consistency improvements (#536) - Disable-class-def (#556) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) -### Opensearch +### 🌀 Miscellaneous - Generate API docs (#324) +- Make tests show coverage (#566) +- Refactor tests (#574) +- Fix opensearch errors bulk write (#594) +- Remove references to Python 3.7 (#601) +- [Elasticsearch] fix: Filters not working with metadata that contain a space or capitalization (#639) +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) ## [integrations/opensearch-v0.2.0] - 2024-01-17 @@ -94,11 +118,16 @@ This PR will also push the docs to Readme - Use `hatch_vcs` to manage integrations versioning (#103) +### 🌀 Miscellaneous + +- Fix opensearch test badge (#97) +- Move package under haystack_integrations/* (#212) + ## [integrations/opensearch-v0.1.1] - 2023-12-05 ### 🐛 Bug Fixes -- Fix import and increase version (#77) +- Document Stores: fix protocol import (#77) ## [integrations/opensearch-v0.1.0] - 2023-12-04 @@ -106,13 +135,17 @@ This PR will also push the docs to Readme - Fix license headers +### 🌀 Miscellaneous + +- Remove Document Store decorator (#76) + ## [integrations/opensearch-v0.0.2] - 2023-11-30 ### 🚀 Features - Extend OpenSearch params support (#70) -### Build +### 🌀 Miscellaneous - Bump OpenSearch integration version to 0.0.2 (#71) From 8a435d92db381b9288bab5ce254b26304fd22565 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Fri, 13 Dec 2024 17:06:06 +0100 Subject: [PATCH 145/229] chore: add application name (#1245) * chore: add application name * fix parentheses to dataframe object --------- Co-authored-by: Mo Sriha <22803208+medsriha@users.noreply.github.com> --- .../snowflake/snowflake_table_retriever.py | 7 ++- .../tests/test_snowflake_table_retriever.py | 62 ++++++++++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index aa6f5ff4d..3cbad3c9d 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -73,6 +73,7 @@ def __init__( db_schema: Optional[str] = None, warehouse: Optional[str] = None, login_timeout: Optional[int] = None, + application_name: Optional[str] = None, ) -> None: """ :param user: User's login. @@ -82,6 +83,7 @@ def __init__( :param db_schema: Name of the schema to use. :param warehouse: Name of the warehouse to use. :param login_timeout: Timeout in seconds for login. By default, 60 seconds. 
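A minimal sketch of the `application_name` option added by this patch might look like the following; the connection details are placeholders, and it assumes the API key is supplied through the component's default environment variable (`SNOWFLAKE_API_KEY`).

```python
# Illustrative sketch only: account, user, and database names are placeholders,
# and the API key is assumed to come from the SNOWFLAKE_API_KEY environment variable.
from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever

retriever = SnowflakeTableRetriever(
    user="john-doe",
    account="my_account",
    database="MY_DB",
    db_schema="PUBLIC",
    warehouse="COMPUTE_WH",
    application_name="haystack_pipeline",  # reported to Snowflake when connecting
)
result = retriever.run(query="SELECT city, state FROM locations LIMIT 5")
print(result["table"])      # text rendering of the result
print(result["dataframe"])  # pandas DataFrame with the same rows
```
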
+ :param application_name: Name of the application to use when connecting to Snowflake. """ self.user = user @@ -91,6 +93,7 @@ def __init__( self.db_schema = db_schema self.warehouse = warehouse self.login_timeout = login_timeout or 60 + self.application_name = application_name def to_dict(self) -> Dict[str, Any]: """ @@ -108,6 +111,7 @@ def to_dict(self) -> Dict[str, Any]: db_schema=self.db_schema, warehouse=self.warehouse, login_timeout=self.login_timeout, + application_name=self.application_name, ) @classmethod @@ -285,6 +289,7 @@ def _fetch_data( "schema": self.db_schema, "warehouse": self.warehouse, "login_timeout": self.login_timeout, + **({"application": self.application_name} if self.application_name else {}), } ) if conn is None: @@ -325,7 +330,7 @@ def run(self, query: str) -> Dict[str, Any]: if not query: logger.error("Provide a valid SQL query.") return { - "dataframe": pd.DataFrame, + "dataframe": pd.DataFrame(), "table": "", } else: diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py index f5b8fee37..3e6e7d547 100644 --- a/integrations/snowflake/tests/test_snowflake_table_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -352,6 +352,64 @@ def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: Snowflake assert result["dataframe"].equals(expected["dataframe"]) assert result["table"] == expected["table"] + mock_connect.assert_called_once_with( + user="test_user", + account="test_account", + password="test-api-key", + database="test_database", + schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run_with_application_name( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + snowflake_table_retriever.application_name = "test_application" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "locations", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + mock_col1.name = "City" + mock_col2.name = "State" + mock_cursor.description = [mock_col1, mock_col2] + + mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT * FROM locations" + + snowflake_table_retriever.run(query=query) + + mock_connect.assert_called_once_with( + user="test_user", + account="test_account", + password="test-api-key", + database="test_database", + schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + application="test_application", + ) @pytest.fixture def mock_chat_completion(self) -> Generator: @@ -494,6 +552,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: "db_schema": "test_schema", "warehouse": "test_warehouse", "login_timeout": 30, + "application_name": None, }, } @@ -508,6 +567,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: db_schema="SMALL_TOWNS", warehouse="COMPUTE_WH", login_timeout=30, + application_name="test_application", ) data = component.to_dict() @@ -529,6 +589,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: 
"db_schema": "SMALL_TOWNS", "warehouse": "COMPUTE_WH", "login_timeout": 30, + "application_name": "test_application", }, } @@ -605,7 +666,6 @@ def test_empty_query(self, snowflake_table_retriever: SnowflakeTableRetriever) - assert result.empty def test_serialization_deserialization_pipeline(self) -> None: - pipeline = Pipeline() pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account")) pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}")) From 0e207915ef018ec5d9e3a00617b37806060a8299 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 13 Dec 2024 16:09:01 +0000 Subject: [PATCH 146/229] Update the changelog --- integrations/snowflake/CHANGELOG.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/integrations/snowflake/CHANGELOG.md b/integrations/snowflake/CHANGELOG.md index 757bfb3fe..356bbcace 100644 --- a/integrations/snowflake/CHANGELOG.md +++ b/integrations/snowflake/CHANGELOG.md @@ -1,13 +1,29 @@ # Changelog +## [integrations/snowflake-v0.0.3] - 2024-12-13 + +### ⚙️ CI + +- Adopt uv as installer (#1142) + +### 🧹 Chores + +- Update ruff linting scripts and settings (#1105) +- Add application name (#1245) + + ## [integrations/snowflake-v0.0.2] - 2024-09-25 ### 🚀 Features - Add Snowflake integration (#1064) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Adding github workflow for Snowflake (#1097) +### 🌀 Miscellaneous + +- Docs: upd snowflake pydoc (#1102) + From 3a3419a21bcc08da7cdf52d033c89656112dad98 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Sat, 14 Dec 2024 23:30:53 +0100 Subject: [PATCH 147/229] ci: delete all azure_ai_search indexes (#1247) * Add a new fixture to clean up undeleted indexes --- .../azure_ai_search/tests/conftest.py | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py index e741e3066..02742031c 100644 --- a/integrations/azure_ai_search/tests/conftest.py +++ b/integrations/azure_ai_search/tests/conftest.py @@ -6,13 +6,10 @@ from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ResourceNotFoundError from azure.search.documents.indexes import SearchIndexClient -from haystack import logging from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore -logger = logging.getLogger(__name__) - # This is the approximate time in seconds it takes for the documents to be available in Azure Search index SLEEP_TIME_IN_SECONDS = 10 @@ -78,8 +75,32 @@ def wait_for_index_deletion(client, index_name): try: client.delete_index(index_name) if not wait_for_index_deletion(client, index_name): - logger.error(f"Index {index_name} was not properly deleted.") + print(f"Index {index_name} was not properly deleted.") except ResourceNotFoundError: - logger.error(f"Index {index_name} was already deleted or not found.") + print(f"Index {index_name} was already deleted or not found.") except Exception as e: - logger.error(f"Unexpected error when deleting index {index_name}: {e}") + print(f"Unexpected error when deleting index {index_name}: {e}") + raise + + +@pytest.fixture(scope="session", autouse=True) +def cleanup_indexes(): + """ + Fixture to clean up all remaining indexes at the end of the test session. + Automatically runs after all tests. 
+ """ + azure_endpoint = os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"] + api_key = os.environ["AZURE_SEARCH_API_KEY"] + + client = SearchIndexClient(azure_endpoint, AzureKeyCredential(api_key)) + + yield # Allow tests to run before performing cleanup + + # Cleanup: Delete all remaining indexes + print("Starting session-level cleanup of all Azure Search indexes.") + existing_indexes = client.list_index_names() + for index in existing_indexes: + try: + client.delete_index(index) + except Exception as e: + print(f"Failed to delete index during clean up {index}: {e}") From 83f7f6990f55ca76cd522c06d68a011d2dddcfc3 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 18 Dec 2024 10:16:51 +0100 Subject: [PATCH 148/229] fix: fixes to Bedrock Chat Generator for compatibility with the new ChatMessage (#1250) --- .../amazon_bedrock/chat/chat_generator.py | 10 +++--- .../tests/test_chat_generator.py | 36 +++++++++---------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 499fe1c24..bcf11414c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -278,12 +278,10 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C # Process each content block separately for content_block in content_blocks: if "text" in content_block: - replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy())) + replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy())) elif "toolUse" in content_block: replies.append( - ChatMessage.from_assistant( - content=json.dumps(content_block["toolUse"]), meta=base_meta.copy() - ) + ChatMessage.from_assistant(json.dumps(content_block["toolUse"]), meta=base_meta.copy()) ) return replies @@ -334,9 +332,9 @@ def process_streaming_response( pass tool_content = json.dumps(current_tool_use) - replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy())) + replies.append(ChatMessage.from_assistant(tool_content, meta=base_meta.copy())) elif current_content: - replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy())) + replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy())) elif "messageStop" in event: # not 100% correct for multiple messages but no way around it diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8eb29729c..c2122163c 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -163,9 +163,9 @@ def test_default_inference_params(self, model_name, chat_messages): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert "paris" in 
first_reply.text.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" if first_reply.meta and "usage" in first_reply.meta: @@ -197,9 +197,9 @@ def streaming_callback(chunk: StreamingChunk): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) @@ -246,7 +246,7 @@ def test_tools_use(self, model_name): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" @@ -254,9 +254,9 @@ def test_tools_use(self, model_name): if len(replies) > 1: second_reply = replies[1] assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.content, "Second reply has no content" + assert second_reply.text, "Second reply has no content" assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.content) + tool_call = json.loads(second_reply.text) assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" @@ -266,7 +266,7 @@ def test_tools_use(self, model_name): else: # case where the model returns the tool call as the first message # double check that the tool call is correct - tool_call = json.loads(first_reply.content) + tool_call = json.loads(first_reply.text) assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" @@ -318,7 +318,7 @@ def test_tools_use_with_streaming(self, model_name): first_reply = replies[0] assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" + assert first_reply.text, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" @@ -326,9 +326,9 @@ def test_tools_use_with_streaming(self, model_name): if len(replies) > 1: second_reply = replies[1] assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.content, "Second reply has no content" + assert second_reply.text, "Second reply has no content" assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.content) 
+ tool_call = json.loads(second_reply.text) assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" @@ -338,7 +338,7 @@ def test_tools_use_with_streaming(self, model_name): else: # case where the model returns the tool call as the first message # double check that the tool call is correct - tool_call = json.loads(first_reply.content) + tool_call = json.loads(first_reply.text) assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" @@ -361,7 +361,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): replies = generator.extract_replies_from_response(text_response) assert len(replies) == 1 - assert replies[0].content == "This is a test response" + assert replies[0].text == "This is a test response" assert replies[0].role == ChatRole.ASSISTANT assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert replies[0].meta["finish_reason"] == "complete" @@ -381,7 +381,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): replies = generator.extract_replies_from_response(tool_response) assert len(replies) == 1 - tool_content = json.loads(replies[0].content) + tool_content = json.loads(replies[0].text) assert tool_content["toolUseId"] == "123" assert tool_content["name"] == "test_tool" assert tool_content["input"] == {"key": "value"} @@ -405,8 +405,8 @@ def test_extract_replies_from_response(self, mock_boto3_session): replies = generator.extract_replies_from_response(mixed_response) assert len(replies) == 2 - assert replies[0].content == "Let me help you with that. I'll use the search tool to find the answer." - tool_content = json.loads(replies[1].content) + assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer." + tool_content = json.loads(replies[1].text) assert tool_content["toolUseId"] == "456" assert tool_content["name"] == "search_tool" assert tool_content["input"] == {"query": "test"} @@ -446,13 +446,13 @@ def test_callback(chunk: StreamingChunk): # Verify final replies assert len(replies) == 2 # Check text reply - assert replies[0].content == "Let me help you." + assert replies[0].text == "Let me help you." 
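The `.content` to `.text` change above is the same `ChatMessage` migration that recurs throughout this patch series (Bedrock, Anthropic, Ollama, llama.cpp, Google Gemini); a generator-independent sketch of the pattern:

```python
# Illustrative sketch of the reworked ChatMessage API these patches adapt to.
from haystack.dataclasses import ChatMessage

# Construction: the text is passed positionally instead of via `content=`.
reply = ChatMessage.from_assistant("Paris is the capital of France.", meta={"model": "example-model"})

# Access: read the message text via `.text` instead of the old `.content` attribute.
assert reply.text == "Paris is the capital of France."
assert reply.meta["model"] == "example-model"
```
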
assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert replies[0].meta["finish_reason"] == "complete" assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} # Check tool use reply - tool_content = json.loads(replies[1].content) + tool_content = json.loads(replies[1].text) assert tool_content["toolUseId"] == "123" assert tool_content["name"] == "search_tool" assert tool_content["input"] == {"query": "test"} From be3789f3ddba3b87b1ac8f7ffbe77f951803f0dc Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 18 Dec 2024 09:17:58 +0000 Subject: [PATCH 149/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 46eeea7b7..6a15d4ad2 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/amazon_bedrock-v2.1.1] - 2024-12-18 + +### 🐛 Bug Fixes + +- Fixes to Bedrock Chat Generator for compatibility with the new ChatMessage (#1250) + + ## [integrations/amazon_bedrock-v2.1.0] - 2024-12-11 ### 🚀 Features From b12461d40f907bdf5daf5862c16b4098aa3f9344 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 18 Dec 2024 15:41:16 +0100 Subject: [PATCH 150/229] fix: make Anthropic compatible with new `ChatMessage`; fix prompt caching tests (#1252) * make Anthropic compatible with new chatmessage; fix prompt caching tests * rm print --- .../generators/anthropic/chat/chat_generator.py | 13 ++++++++++--- integrations/anthropic/tests/test_chat_generator.py | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 43b50495c..56a740146 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -1,4 +1,3 @@ -import dataclasses import json from typing import Any, Callable, ClassVar, Dict, List, Optional, Union @@ -275,8 +274,16 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict """ anthropic_formatted_messages = [] for m in messages: - message_dict = dataclasses.asdict(m) - formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + message_dict = m.to_dict() + formatted_message = {} + + # legacy format + if "role" in message_dict and "content" in message_dict: + formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + # new format + elif "_role" in message_dict and "_content" in message_dict: + formatted_message = {"role": m.role.value, "content": m.text} + if m.is_from(ChatRole.SYSTEM): # system messages are treated differently and MUST be in the format expected by the Anthropic API # remove role and content from the message dict, add type and text diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 9a111fc9d..36622ecd9 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -428,5 +428,5 @@ def test_prompt_caching(self, cache_enabled): or 
token_usage.get("cache_read_input_tokens") > 1024 ) else: - assert "cache_creation_input_tokens" not in token_usage - assert "cache_read_input_tokens" not in token_usage + assert token_usage["cache_creation_input_tokens"] == 0 + assert token_usage["cache_read_input_tokens"] == 0 From 06e9c5615bdba322f769a69717cf2e5ff46bc494 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 18 Dec 2024 14:44:01 +0000 Subject: [PATCH 151/229] Update the changelog --- integrations/anthropic/CHANGELOG.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/integrations/anthropic/CHANGELOG.md b/integrations/anthropic/CHANGELOG.md index a7cdc7d09..c14a8032d 100644 --- a/integrations/anthropic/CHANGELOG.md +++ b/integrations/anthropic/CHANGELOG.md @@ -1,6 +1,10 @@ # Changelog -## [unreleased] +## [integrations/anthropic-v1.2.1] - 2024-12-18 + +### 🐛 Bug Fixes + +- Make Anthropic compatible with new `ChatMessage`; fix prompt caching tests (#1252) ### ⚙️ CI @@ -9,10 +13,13 @@ ### 🧹 Chores - Update ruff linting scripts and settings (#1105) +- Fix linting/isort (#1215) ### 🌀 Miscellaneous - Add AnthropicVertexChatGenerator component (#1192) +- Docs: add AnthropicVertexChatGenerator to pydoc (#1221) +- Chore: use `text` instead of `content` for `ChatMessage` in Cohere and Anthropic (#1237) ## [integrations/anthropic-v1.1.0] - 2024-09-20 From 7f62ca8f0b50352b0c99eae0ab0391c7305eef75 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 18 Dec 2024 16:21:28 +0100 Subject: [PATCH 152/229] ci: make nightly tests actually run with Haystack main branch (#1251) * fix nightly runs with haystack main * feedback --- .github/workflows/amazon_bedrock.yml | 2 +- .github/workflows/amazon_sagemaker.yml | 2 +- .github/workflows/anthropic.yml | 2 +- .github/workflows/astra.yml | 2 +- .github/workflows/azure_ai_search.yml | 2 +- .github/workflows/chroma.yml | 2 +- .github/workflows/cohere.yml | 2 +- .github/workflows/deepeval.yml | 2 +- .github/workflows/elasticsearch.yml | 2 +- .github/workflows/fastembed.yml | 2 +- .github/workflows/google_ai.yml | 2 +- .github/workflows/google_vertex.yml | 2 +- .github/workflows/instructor_embedders.yml | 2 +- .github/workflows/jina.yml | 2 +- .github/workflows/langfuse.yml | 2 +- .github/workflows/llama_cpp.yml | 2 +- .github/workflows/mistral.yml | 2 +- .github/workflows/mongodb_atlas.yml | 2 +- .github/workflows/nvidia.yml | 2 +- .github/workflows/ollama.yml | 2 +- .github/workflows/opensearch.yml | 2 +- .github/workflows/optimum.yml | 2 +- .github/workflows/pgvector.yml | 2 +- .github/workflows/pinecone.yml | 2 +- .github/workflows/qdrant.yml | 2 +- .github/workflows/ragas.yml | 2 +- .github/workflows/snowflake.yml | 2 +- .github/workflows/unstructured.yml | 2 +- .github/workflows/weaviate.yml | 2 +- 29 files changed, 29 insertions(+), 29 deletions(-) diff --git a/.github/workflows/amazon_bedrock.yml b/.github/workflows/amazon_bedrock.yml index 2057d4bdf..be356ee4c 100644 --- a/.github/workflows/amazon_bedrock.yml +++ b/.github/workflows/amazon_bedrock.yml @@ -81,7 +81,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/amazon_sagemaker.yml b/.github/workflows/amazon_sagemaker.yml index ed0a571e6..646f22e02 100644 --- 
a/.github/workflows/amazon_sagemaker.yml +++ b/.github/workflows/amazon_sagemaker.yml @@ -62,7 +62,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/anthropic.yml b/.github/workflows/anthropic.yml index 52ba5c9d4..6dcc681bc 100644 --- a/.github/workflows/anthropic.yml +++ b/.github/workflows/anthropic.yml @@ -59,7 +59,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/astra.yml b/.github/workflows/astra.yml index dcfc00c75..55ff7552c 100644 --- a/.github/workflows/astra.yml +++ b/.github/workflows/astra.yml @@ -66,7 +66,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/azure_ai_search.yml b/.github/workflows/azure_ai_search.yml index 1c10edc91..7ead7d113 100644 --- a/.github/workflows/azure_ai_search.yml +++ b/.github/workflows/azure_ai_search.yml @@ -60,7 +60,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/chroma.yml b/.github/workflows/chroma.yml index 6dbf36d85..323a8a0b3 100644 --- a/.github/workflows/chroma.yml +++ b/.github/workflows/chroma.yml @@ -66,7 +66,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/cohere.yml b/.github/workflows/cohere.yml index 00a8ee2ed..2f74cd66a 100644 --- a/.github/workflows/cohere.yml +++ b/.github/workflows/cohere.yml @@ -63,7 +63,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/deepeval.yml b/.github/workflows/deepeval.yml index 23de1a3f4..3d320cb8f 100644 --- a/.github/workflows/deepeval.yml +++ b/.github/workflows/deepeval.yml @@ -63,7 +63,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip 
install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/elasticsearch.yml b/.github/workflows/elasticsearch.yml index 476e832b5..2f34d7d19 100644 --- a/.github/workflows/elasticsearch.yml +++ b/.github/workflows/elasticsearch.yml @@ -60,7 +60,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/fastembed.yml b/.github/workflows/fastembed.yml index e389bf3a4..431f947dc 100644 --- a/.github/workflows/fastembed.yml +++ b/.github/workflows/fastembed.yml @@ -47,7 +47,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/google_ai.yml b/.github/workflows/google_ai.yml index 1b4b2e496..25ee0d020 100644 --- a/.github/workflows/google_ai.yml +++ b/.github/workflows/google_ai.yml @@ -63,7 +63,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/google_vertex.yml b/.github/workflows/google_vertex.yml index 34c0cf07c..da9f6d7b3 100644 --- a/.github/workflows/google_vertex.yml +++ b/.github/workflows/google_vertex.yml @@ -62,7 +62,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/instructor_embedders.yml b/.github/workflows/instructor_embedders.yml index f12f4d696..01e9798c3 100644 --- a/.github/workflows/instructor_embedders.yml +++ b/.github/workflows/instructor_embedders.yml @@ -40,7 +40,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/jina.yml b/.github/workflows/jina.yml index 00af6eb45..c6d506714 100644 --- a/.github/workflows/jina.yml +++ b/.github/workflows/jina.yml @@ -62,7 +62,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m 
"not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/langfuse.yml b/.github/workflows/langfuse.yml index 8a10cf241..0cac22f91 100644 --- a/.github/workflows/langfuse.yml +++ b/.github/workflows/langfuse.yml @@ -65,7 +65,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/llama_cpp.yml b/.github/workflows/llama_cpp.yml index a9480ca96..b3ec462d3 100644 --- a/.github/workflows/llama_cpp.yml +++ b/.github/workflows/llama_cpp.yml @@ -62,7 +62,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/mistral.yml b/.github/workflows/mistral.yml index e62008906..f0ce2f75a 100644 --- a/.github/workflows/mistral.yml +++ b/.github/workflows/mistral.yml @@ -63,7 +63,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/mongodb_atlas.yml b/.github/workflows/mongodb_atlas.yml index 3fd2a43ac..690d52cd4 100644 --- a/.github/workflows/mongodb_atlas.yml +++ b/.github/workflows/mongodb_atlas.yml @@ -60,7 +60,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/nvidia.yml b/.github/workflows/nvidia.yml index 0d39a4d91..39a5d062e 100644 --- a/.github/workflows/nvidia.yml +++ b/.github/workflows/nvidia.yml @@ -64,7 +64,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/ollama.yml b/.github/workflows/ollama.yml index 43af485b7..8be6bf263 100644 --- a/.github/workflows/ollama.yml +++ b/.github/workflows/ollama.yml @@ -80,7 +80,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/opensearch.yml b/.github/workflows/opensearch.yml index 48169a75f..32c9c4e6e 100644 --- a/.github/workflows/opensearch.yml +++ 
b/.github/workflows/opensearch.yml @@ -60,7 +60,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/optimum.yml b/.github/workflows/optimum.yml index c33baa7f8..eb7b877d3 100644 --- a/.github/workflows/optimum.yml +++ b/.github/workflows/optimum.yml @@ -62,7 +62,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/pgvector.yml b/.github/workflows/pgvector.yml index ab5c984ed..d1727e1fc 100644 --- a/.github/workflows/pgvector.yml +++ b/.github/workflows/pgvector.yml @@ -66,7 +66,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/pinecone.yml b/.github/workflows/pinecone.yml index 9e143005b..7e34c70de 100644 --- a/.github/workflows/pinecone.yml +++ b/.github/workflows/pinecone.yml @@ -68,7 +68,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/qdrant.yml b/.github/workflows/qdrant.yml index 116225b2d..2bb24d774 100644 --- a/.github/workflows/qdrant.yml +++ b/.github/workflows/qdrant.yml @@ -62,7 +62,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/ragas.yml b/.github/workflows/ragas.yml index c4757e704..953f1ffee 100644 --- a/.github/workflows/ragas.yml +++ b/.github/workflows/ragas.yml @@ -63,7 +63,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/snowflake.yml b/.github/workflows/snowflake.yml index 19596f312..98b30fc97 100644 --- a/.github/workflows/snowflake.yml +++ b/.github/workflows/snowflake.yml @@ -62,7 +62,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip 
install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/unstructured.yml b/.github/workflows/unstructured.yml index e4b640275..5a9aa77a3 100644 --- a/.github/workflows/unstructured.yml +++ b/.github/workflows/unstructured.yml @@ -74,7 +74,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures diff --git a/.github/workflows/weaviate.yml b/.github/workflows/weaviate.yml index 36c30f069..4368e62c0 100644 --- a/.github/workflows/weaviate.yml +++ b/.github/workflows/weaviate.yml @@ -60,7 +60,7 @@ jobs: - name: Nightly - run unit tests with Haystack main branch if: github.event_name == 'schedule' run: | - hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main hatch run cov-retry -m "not integration" - name: Send event to Datadog for nightly failures From aff674850a69ff7277f823d06dfb20f82e2cfdd2 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 18 Dec 2024 18:29:20 +0100 Subject: [PATCH 153/229] fix: make Ollama Chat Generator compatible with new ChatMessage (#1256) --- .../components/generators/ollama/chat/chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index f598a6e42..daae552e5 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -116,7 +116,7 @@ def _build_message_from_ollama_response(self, ollama_response: ChatResponse) -> Converts the non-streaming response from the Ollama API to a ChatMessage. 
""" response_dict = ollama_response.model_dump() - message = ChatMessage.from_assistant(content=response_dict["message"]["content"]) + message = ChatMessage.from_assistant(response_dict["message"]["content"]) message.meta.update({key: value for key, value in response_dict.items() if key != "message"}) return message From 4c478afd109939a0d989f932f1590b89256a6bf7 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 18 Dec 2024 17:31:34 +0000 Subject: [PATCH 154/229] Update the changelog --- integrations/ollama/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index c28767257..7aeb70021 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/ollama-v2.1.2] - 2024-12-18 + +### 🐛 Bug Fixes + +- Make Ollama Chat Generator compatible with new ChatMessage (#1256) + + ## [integrations/ollama-v2.1.1] - 2024-12-10 ### 🌀 Miscellaneous From 1aba3079577ba062caa87c95cdc1cee50b545e10 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 19 Dec 2024 10:07:37 +0100 Subject: [PATCH 155/229] fix: make llama.cpp Chat Generator compatible with new `ChatMessage` (#1254) * progress * remove vertex changes from this PR * fix --- .../llama_cpp/chat/chat_generator.py | 51 +++++++++++-------- .../llama_cpp/tests/test_chat_generator.py | 10 ++-- 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py index 014dd7169..d2150f61f 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional from haystack import component -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses import ChatMessage from llama_cpp import Llama from llama_cpp.llama_tokenizer import LlamaHFTokenizer @@ -21,6 +21,10 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]: if message.name: formatted_msg["name"] = message.name + if formatted_msg["role"] == "tool": + formatted_msg["name"] = message.tool_call_result.origin.tool_name + formatted_msg["content"] = message.tool_call_result.result + return formatted_msg @@ -114,26 +118,31 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages] response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs) - replies = [ - ChatMessage( - content=choice["message"]["content"], - role=ChatRole[choice["message"]["role"].upper()], - name=None, - meta={ - "response_id": response["id"], - "model": response["model"], - "created": response["created"], - "index": choice["index"], - "finish_reason": choice["finish_reason"], - "usage": response["usage"], - }, - ) - for choice in response["choices"] - ] - - for reply, choice in zip(replies, response["choices"]): + + replies = [] + + for choice in response["choices"]: + meta = { + "response_id": response["id"], + "model": response["model"], + "created": response["created"], + "index": choice["index"], + "finish_reason": choice["finish_reason"], + "usage": 
response["usage"], + } + + name = None tool_calls = choice.get("message", {}).get("tool_calls", []) if tool_calls: - reply.meta["tool_calls"] = tool_calls - reply.name = tool_calls[0]["function"]["name"] if tool_calls else None + meta["tool_calls"] = tool_calls + name = tool_calls[0]["function"]["name"] + + reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta) + if name: + if hasattr(reply, "_name"): + reply._name = name # new ChatMessage + elif hasattr(reply, "name"): + reply.name = name # legacy ChatMessage + replies.append(reply) + return {"replies": replies} diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py index 0ddd78c4f..87639f684 100644 --- a/integrations/llama_cpp/tests/test_chat_generator.py +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -41,11 +41,11 @@ def test_convert_message_to_llamacpp_format(): assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"} message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_llamacpp_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", - } + converted_message = _convert_message_to_llamacpp_format(message) + + assert converted_message["role"] in ("function", "tool") + assert converted_message["name"] == "function_name" + assert converted_message["content"] == "Function call" class TestLlamaCppChatGenerator: From 4f067b9b183020d039be017990eb7f868e116099 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 19 Dec 2024 09:09:30 +0000 Subject: [PATCH 156/229] Update the changelog --- integrations/llama_cpp/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/llama_cpp/CHANGELOG.md b/integrations/llama_cpp/CHANGELOG.md index 2d4a8c86e..930486a0d 100644 --- a/integrations/llama_cpp/CHANGELOG.md +++ b/integrations/llama_cpp/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/llama_cpp-v0.4.3] - 2024-12-19 + +### 🐛 Bug Fixes + +- Make llama.cpp Chat Generator compatible with new `ChatMessage` (#1254) + + ## [integrations/llama_cpp-v0.4.2] - 2024-12-10 ### 🧪 Testing From 07ca7a67aaa4db518090cdd42b0fa87864640417 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 19 Dec 2024 12:44:21 +0100 Subject: [PATCH 157/229] fix: make Google Vertex Chat Generator compatible with new ChatMessage (#1255) * make Vertex compatible with new ChatMessage * fmt --- .../components/generators/google_vertex/chat/gemini.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 2309ca718..845e24f5f 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -280,12 +280,10 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: # Remove content from metadata metadata.pop("content", None) if part._raw_part.text != "": - replies.append(ChatMessage.from_assistant(content=part._raw_part.text, meta=metadata)) + replies.append(ChatMessage.from_assistant(part._raw_part.text, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call - new_message = 
ChatMessage.from_assistant( - content=json.dumps(dict(part.function_call.args)), meta=metadata - ) + new_message = ChatMessage.from_assistant(json.dumps(dict(part.function_call.args)), meta=metadata) new_message.name = part.function_call.name replies.append(new_message) return replies From e35d3cb1688c9195ea39409fbda774ad1706a9ca Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 19 Dec 2024 11:45:36 +0000 Subject: [PATCH 158/229] Update the changelog --- integrations/google_vertex/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index 71f433509..7cf633ebd 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/google_vertex-v4.0.1] - 2024-12-19 + +### 🐛 Bug Fixes + +- Make Google Vertex Chat Generator compatible with new ChatMessage (#1255) + + ## [integrations/google_vertex-v4.0.0] - 2024-12-11 ### 🐛 Bug Fixes From 58cb13522cfe9e334b478c346a1f85e238f51a17 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 19 Dec 2024 15:01:16 +0100 Subject: [PATCH 159/229] fix: make GoogleAI Chat Generator compatible with new `ChatMessage`; small fixes to Cohere tests (#1253) * draft * improvements * small improvemtn * rm duplication * simplification --- .../tests/test_cohere_chat_generator.py | 2 +- .../generators/google_ai/chat/gemini.py | 32 +++++++++++++++---- .../tests/generators/chat/test_chat_gemini.py | 18 +++++------ 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 09f3708eb..4aaa2da2b 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -27,7 +27,7 @@ def streaming_chunk(text: str): @pytest.fixture def chat_messages(): - return [ChatMessage.from_assistant(content="What's the capital of France")] + return [ChatMessage.from_assistant("What's the capital of France")] class TestCohereChatGenerator: diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 089b38b10..69f168a6b 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -247,6 +247,11 @@ def _message_to_part(self, message: ChatMessage) -> Part: p.function_response.name = message.name p.function_response.response = message.text return p + elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): + p = Part() + p.function_response.name = message.tool_call_result.origin.tool_name + p.function_response.response = message.tool_call_result.result + return p elif message.is_from(ChatRole.USER): return self._convert_part(message.text) @@ -266,10 +271,17 @@ def _message_to_content(self, message: ChatMessage) -> Content: part.function_response.response = message.text elif message.is_from(ChatRole.USER): part = self._convert_part(message.text) + elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): + part = Part() + part.function_response.name = message.tool_call_result.origin.tool_name + part.function_response.response = message.tool_call_result.result else: msg = f"Unsupported message role 
{message.role}" raise ValueError(msg) - role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" + + role = "user" + if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM): + role = "model" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) @@ -335,13 +347,16 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess for part in candidate.content.parts: if part.text != "": - replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata)) + replies.append(ChatMessage.from_assistant(part.text, meta=candidate_metadata)) elif part.function_call: candidate_metadata["function_call"] = part.function_call new_message = ChatMessage.from_assistant( - content=json.dumps(dict(part.function_call.args)), meta=candidate_metadata + json.dumps(dict(part.function_call.args)), meta=candidate_metadata ) - new_message.name = part.function_call.name + try: + new_message.name = part.function_call.name + except AttributeError: + new_message._name = part.function_call.name replies.append(new_message) return replies @@ -364,12 +379,15 @@ def _get_stream_response( for part in candidate["content"]["parts"]: if "text" in part and part["text"] != "": content = part["text"] - replies.append(ChatMessage.from_assistant(content=content, meta=metadata)) + replies.append(ChatMessage.from_assistant(content, meta=metadata)) elif "function_call" in part and len(part["function_call"]) > 0: metadata["function_call"] = part["function_call"] content = json.dumps(dict(part["function_call"]["args"])) - new_message = ChatMessage.from_assistant(content=content, meta=metadata) - new_message.name = part["function_call"]["name"] + new_message = ChatMessage.from_assistant(content, meta=metadata) + try: + new_message.name = part["function_call"]["name"] + except AttributeError: + new_message._name = part["function_call"]["name"] replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index b8658a4dd..0683bf21a 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -215,7 +215,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool]) - messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -227,7 +227,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} weather = get_current_weather(**json.loads(chat_message.text)) - messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -260,7 +260,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # 
noqa: ARG001 tool = Tool(function_declarations=[get_current_weather_func]) gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback) - messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")] + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -272,8 +272,8 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert "function_call" in chat_message.meta assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - weather = get_current_weather(**json.loads(response["replies"][0].text)) - messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] + weather = get_current_weather(**json.loads(chat_message.text)) + messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 @@ -289,10 +289,10 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 def test_past_conversation(): gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro") messages = [ - ChatMessage.from_system(content="You are a knowledageable mathematician."), - ChatMessage.from_user(content="What is 2+2?"), - ChatMessage.from_assistant(content="It's an arithmetic operation."), - ChatMessage.from_user(content="Yeah, but what's the result?"), + ChatMessage.from_system("You are a knowledageable mathematician."), + ChatMessage.from_user("What is 2+2?"), + ChatMessage.from_assistant("It's an arithmetic operation."), + ChatMessage.from_user("Yeah, but what's the result?"), ] response = gemini_chat.run(messages=messages) assert "replies" in response From 6b8d02d4dd5de710f4e20981b2da525194a5d462 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 19 Dec 2024 14:02:46 +0000 Subject: [PATCH 160/229] Update the changelog --- integrations/google_ai/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 404303412..71cdf4e74 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/google_ai-v4.0.1] - 2024-12-19 + +### 🐛 Bug Fixes + +- Make GoogleAI Chat Generator compatible with new `ChatMessage`; small fixes to Cohere tests (#1253) + + ## [integrations/google_ai-v4.0.0] - 2024-12-10 ### 🐛 Bug Fixes From bbf7417c709cd9a7e2f81fb63ef446d72019e26d Mon Sep 17 00:00:00 2001 From: Vedant Yadav <104881513+TheMimikyu@users.noreply.github.com> Date: Thu, 2 Jan 2025 15:07:30 +0530 Subject: [PATCH 161/229] feat: add model `nvidia/llama-3.2-nv-rerankqa-1b-v2` to `_MODEL_ENDPOINT_MAP` (#1260) * fix: add `nvidia/llama-3.2-nv-rerankqa-1b-v2` endpoint to `_MODEL_ENDPOINT_MAP` * chore: lint --- .../haystack_integrations/components/rankers/nvidia/ranker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index 66203a490..ec1af7ab8 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ 
-19,6 +19,7 @@ _MODEL_ENDPOINT_MAP = { "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", "nvidia/llama-3.2-nv-rerankqa-1b-v1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v1/reranking", + "nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking", } From fed2743c28ae06f63fec8f9cf687ce0ee6605113 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 2 Jan 2025 09:44:26 +0000 Subject: [PATCH 162/229] Update the changelog --- integrations/nvidia/CHANGELOG.md | 53 +++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index a536e431d..2523907d6 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -1,11 +1,24 @@ # Changelog +## [integrations/nvidia-v0.1.3] - 2025-01-02 + +### 🚀 Features + +- Improvements to NvidiaRanker and adding user input timeout (#1193) +- Add model `nvidia/llama-3.2-nv-rerankqa-1b-v2` to `_MODEL_ENDPOINT_MAP` (#1260) + +### 🧹 Chores + +- Fix linting/isort (#1215) + + ## [integrations/nvidia-v0.1.1] - 2024-11-14 ### 🐛 Bug Fixes - Fixes to NvidiaRanker (#1191) + ## [integrations/nvidia-v0.1.0] - 2024-11-13 ### 🚀 Features @@ -31,16 +44,25 @@ - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) +- Adopt uv as installer (#1142) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) - Update ruff linting scripts and settings (#1105) -- Adopt uv as installer (#1142) -### Docs +### 🌀 Miscellaneous +- Fix: make hosted nim default (#734) +- Fix: align tests and docs on NVIDIA_API_KEY (instead of NVIDIA_CATALOG_API_KEY) (#731) +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Raise warning for base_url ../embeddings .../completions .../rankings (#922) - Update NvidiaGenerator docstrings (#966) +- Add default model for NVIDIA HayStack local NIM endpoints (#915) +- Feat: add nvidia/llama-3.2-nv-rerankqa-1b-v1 to set of known ranking models (#1183) ## [integrations/nvidia-v0.0.3] - 2024-05-22 @@ -48,16 +70,30 @@ - Update docstrings of Nvidia integrations (#599) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Add generate docs to Nvidia workflow (#603) +### 🌀 Miscellaneous + +- Remove references to Python 3.7 (#601) +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Update Nvidia integration to support new endpoints (#701) +- Docs: add missing api references (#728) +- Update _nim_backend.py (#744) + ## [integrations/nvidia-v0.0.2] - 2024-03-18 ### 📚 Documentation - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Make tests show coverage (#566) +- Add NIM backend support (#597) + ## [integrations/nvidia-v0.0.1] - 2024-03-07 ### 🚀 Features @@ -68,6 +104,15 @@ - `nvidia-haystack`- Handle non-strict env var secrets correctly (#543) +### 🌀 Miscellaneous + +- Add `NvidiaGenerator` (#557) +- Add missing import in NvidiaGenerator docstring (#559) + ## [integrations/nvidia-v0.0.0] - 2024-03-01 +### 🌀 Miscellaneous + +- Add Nvidia integration scaffold (#515) + From b2ca80024294c5e87b543835eae309a3636cfec5 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 2 Jan 2025 11:04:35 +0100 Subject: [PATCH 163/229] update to chroma 0.6.0 (#1270) --- .github/workflows/chroma.yml | 2 +- 
integrations/chroma/pyproject.toml | 6 ++---- .../document_stores/chroma/document_store.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/chroma.yml b/.github/workflows/chroma.yml index 323a8a0b3..93008f863 100644 --- a/.github/workflows/chroma.yml +++ b/.github/workflows/chroma.yml @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11"] steps: - name: Support longpaths diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 40bc9a2b3..589b75f78 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -7,7 +7,7 @@ name = "chroma-haystack" dynamic = ["version"] description = '' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] authors = [{ name = "John Doe", email = "jd@example.com" }] @@ -15,7 +15,6 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -24,8 +23,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "chromadb>=0.5.17", - "typing_extensions>=4.8.0" + "chromadb>=0.6.0", ] [project.urls] diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 439e4b144..5dee9c889 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -116,7 +116,7 @@ def _ensure_initialized(self): if "hnsw:space" not in self._metadata: self._metadata["hnsw:space"] = self._distance_function - if self._collection_name in [c.name for c in client.list_collections()]: + if self._collection_name in client.list_collections(): self._collection = client.get_collection(self._collection_name, embedding_function=self._embedding_func) if self._metadata != self._collection.metadata: From 122a27b7885398b717fed01907193d5f658ccb36 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 2 Jan 2025 10:05:50 +0000 Subject: [PATCH 164/229] Update the changelog --- integrations/chroma/CHANGELOG.md | 99 ++++++++++++++++++++++++++++---- 1 file changed, 88 insertions(+), 11 deletions(-) diff --git a/integrations/chroma/CHANGELOG.md b/integrations/chroma/CHANGELOG.md index 591c0ec39..7cb32db67 100644 --- a/integrations/chroma/CHANGELOG.md +++ b/integrations/chroma/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [integrations/chroma-v2.0.0] - 2025-01-02 + +### 🧹 Chores + +- Fix linting/isort (#1215) +- Chroma - pin `tokenizers` (#1223) + +### 🌀 Miscellaneous + +- Unpin tokenizers (#1233) +- Fix: updates for Chroma 0.6.0 (#1270) + ## [integrations/chroma-v1.0.0] - 2024-11-06 ### 🐛 Bug Fixes @@ -7,13 +19,14 @@ - Fixing Chroma tests due `chromadb` update behaviour change (#1148) - Adapt our implementation to breaking changes in Chroma 0.5.17 (#1165) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Adopt uv as installer (#1142) + ## [integrations/chroma-v0.22.1] - 2024-09-30 -### Chroma +### 🌀 Miscellaneous - Empty filters should behave as no filters (#1117) @@ -26,32 +39,36 @@ ### 🐛 Bug Fixes -- Fix chroma linting; rm numpy 
(#1063) - -Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> +- Refactor: fix chroma linting; do not use numpy (#1063) - Filters in chroma integration (#1072) ### 🧪 Testing - Do not retry tests in `hatch run test` command (#954) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Chroma - ruff update, don't ruff tests (#983) - Update ruff linting scripts and settings (#1105) +### 🌀 Miscellaneous + +- Chore: ChromaDocumentStore lint fix (#1065) + ## [integrations/chroma-v0.21.1] - 2024-07-17 ### 🐛 Bug Fixes - `ChromaDocumentStore` - discard `meta` items when the type of their value is not supported in Chroma (#907) + ## [integrations/chroma-v0.21.0] - 2024-07-16 ### 🚀 Features - Add metadata parameter to ChromaDocumentStore. (#906) + ## [integrations/chroma-v0.20.1] - 2024-07-15 ### 🚀 Features @@ -64,15 +81,32 @@ Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> - Allow search in ChromaDocumentStore without metadata (#863) - `Chroma` - Fallback to default filter policy when deserializing retrievers without the init parameter (#897) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) +### 🌀 Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Chore: Minor retriever pydoc fix (#884) + ## [integrations/chroma-v0.18.0] - 2024-05-31 +### 🌀 Miscellaneous + +- Chore: pin chromadb>=0.5.0 (#777) + ## [integrations/chroma-v0.17.0] - 2024-05-10 +### 🌀 Miscellaneous + +- Chore: change the pydoc renderer class (#718) +- Implement filters for chromaQueryTextRetriever via existing haystack filters logic (#705) + ## [integrations/chroma-v0.16.0] - 2024-05-02 ### 📚 Documentation @@ -80,15 +114,27 @@ Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> - Small consistency improvements (#536) - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Make tests show coverage (#566) +- Refactor tests (#574) +- Remove references to Python 3.7 (#601) +- Make Document Stores initially skip `SparseEmbedding` (#606) +- Pin databind-core (#619) +- Chore: add license classifiers (#680) +- Feature/bump chromadb dep to 0.5.0 (#700) + ## [integrations/chroma-v0.15.0] - 2024-03-01 +### 🌀 Miscellaneous + +- Release chroma on python 3.8 (#512) + ## [integrations/chroma-v0.14.0] - 2024-02-29 ### 🐛 Bug Fixes - Fix order of API docs (#447) - -This PR will also push the docs to Readme - Serialize the path to the local db (#506) ### 📚 Documentation @@ -96,35 +142,59 @@ This PR will also push the docs to Readme - Update category slug (#442) - Review chroma integration (#501) +### 🌀 Miscellaneous + +- Small improvements (#443) +- Fix: make write_documents compatible with the DocumentStore protocol (#505) + ## [integrations/chroma-v0.13.0] - 2024-02-13 +### 🌀 Miscellaneous + +- Chroma: rename retriever (#407) + ## [integrations/chroma-v0.12.0] - 2024-02-06 ### 🚀 Features - Generate API docs (#262) +### 🌀 Miscellaneous + +- Add typing_extensions pin to Chroma integration (#295) +- Allows filters and persistent document stores for Chroma (#342) + ## [integrations/chroma-v0.11.0] - 2024-01-18 ### 🐛 Bug Fixes - Chroma DocumentStore creation for pre-existing collection name (#157) +### 🌀 Miscellaneous + +- Mount chroma integration under `haystack_integrations.*` (#193) +- Remove ChromaSingleQueryRetriever (#240) + ## [integrations/chroma-v0.9.0] - 2023-12-20 ### 🐛 Bug Fixes -- Fix project urls (#96) +- 
Fix project URLs (#96) ### 🚜 Refactor - Use `hatch_vcs` to manage integrations versioning (#103) +### 🌀 Miscellaneous + +- Chore: pin chroma version (#104) +- Fix: update to the latest Document format (#127) + ## [integrations/chroma-v0.8.1] - 2023-12-05 ### 🐛 Bug Fixes -- Fix import and increase version (#77) +- Document Stores: fix protocol import (#77) ## [integrations/chroma-v0.8.0] - 2023-12-04 @@ -132,4 +202,11 @@ This PR will also push the docs to Readme - Fix license headers +### 🌀 Miscellaneous + +- Reorganize repository (#62) +- Update import paths (#64) +- Patch chroma filters tests (#67) +- Remove Document Store decorator (#76) + From 01bbadd594090782afea42c65dabccf950f219ed Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 2 Jan 2025 15:06:07 +0100 Subject: [PATCH 165/229] Adapt Mistral to OpenAI refactoring (#1271) --- .../tests/test_mistral_chat_generator.py | 58 ++++++++++++------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py index 6277b9c36..be3dce497 100644 --- a/integrations/mistral/tests/test_mistral_chat_generator.py +++ b/integrations/mistral/tests/test_mistral_chat_generator.py @@ -80,18 +80,24 @@ def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key") component = MistralChatGenerator() data = component.to_dict() - assert data == { - "type": "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, - "model": "mistral-tiny", - "organization": None, - "streaming_callback": None, - "api_base_url": "https://api.mistral.ai/v1", - "generation_kwargs": {}, - }, + + assert ( + data["type"] + == "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator" + ) + + expected_params = { + "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, + "model": "mistral-tiny", + "organization": None, + "streaming_callback": None, + "api_base_url": "https://api.mistral.ai/v1", + "generation_kwargs": {}, } + for key, value in expected_params.items(): + assert data["init_parameters"][key] == value + def test_to_dict_with_parameters(self, monkeypatch): monkeypatch.setenv("ENV_VAR", "test-api-key") component = MistralChatGenerator( @@ -102,18 +108,23 @@ def test_to_dict_with_parameters(self, monkeypatch): generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) data = component.to_dict() - assert data == { - "type": "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}, - "model": "mistral-small", - "api_base_url": "test-base-url", - "organization": None, - "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, + + assert ( + data["type"] + == "haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator" + ) + + expected_params = { + "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}, + "model": "mistral-small", + "api_base_url": "test-base-url", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": 
"test-params"}, } + for key, value in expected_params.items(): + assert data["init_parameters"][key] == value + def test_from_dict(self, monkeypatch): monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") data = { @@ -187,7 +198,12 @@ def test_check_abnormal_completions(self, caplog): ] for m in messages: - component._check_finish_reason(m) + try: + # Haystack >= 2.9.0 + component._check_finish_reason(m.meta) + except AttributeError: + # Haystack < 2.9.0 + component._check_finish_reason(m) # check truncation warning message_template = ( From 081260c97b24fae21e548588761181905b99bb16 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 2 Jan 2025 15:21:05 +0100 Subject: [PATCH 166/229] langfuse: fix messages conversion to OpenAI format (#1272) --- .../tracing/langfuse/tracer.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index 6af05633e..b6bc96860 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -5,14 +5,17 @@ from typing import Any, Dict, Iterator, List, Optional, Union from haystack import logging -from haystack.components.generators.openai_utils import _convert_message_to_openai_format from haystack.dataclasses import ChatMessage +from haystack.lazy_imports import LazyImport from haystack.tracing import Span, Tracer from haystack.tracing import tracer as proxy_tracer from haystack.tracing import utils as tracing_utils import langfuse +with LazyImport("") as openai_utils_import: + from haystack.components.generators.openai_utils import _convert_message_to_openai_format + logger = logging.getLogger(__name__) HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH" @@ -83,14 +86,24 @@ def set_content_tag(self, key: str, value: Any) -> None: return if key.endswith(".input"): if "messages" in value: - messages = [_convert_message_to_openai_format(m) for m in value["messages"]] + if openai_utils_import.is_successful(): + # Haystack < 2.9.0 + messages = [_convert_message_to_openai_format(m) for m in value["messages"]] + else: + # Haystack >= 2.9.0 + messages = [m.to_openai_dict_format() for m in value["messages"]] self._span.update(input=messages) else: self._span.update(input=value) elif key.endswith(".output"): if "replies" in value: if all(isinstance(r, ChatMessage) for r in value["replies"]): - replies = [_convert_message_to_openai_format(m) for m in value["replies"]] + if openai_utils_import.is_successful(): + # Haystack < 2.9.0 + replies = [_convert_message_to_openai_format(m) for m in value["replies"]] + else: + # Haystack >= 2.9.0 + replies = [m.to_openai_dict_format() for m in value["replies"]] else: replies = value["replies"] self._span.update(output=replies) From 67e4ab6d4dd8fcbed33280688073c177f6f7f01d Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 2 Jan 2025 14:25:49 +0000 Subject: [PATCH 167/229] Update the changelog --- integrations/langfuse/CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 0ecb42b48..a3148496d 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## [unreleased] +## [integrations/langfuse-v0.6.2] - 2025-01-02 ### 🚀 Features @@ -13,6 +13,7 @@ ### 🌀 
Miscellaneous - Chore: Fix tracing_context_var lint errors (#1220) +- Fix messages conversion to OpenAI format (#1272) ## [integrations/langfuse-v0.6.0] - 2024-11-18 From f3c7cd4148a67f4b6b89d9d0eb7cb56cbd3eff82 Mon Sep 17 00:00:00 2001 From: Daniel Kleine <53251018+d-kleine@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:55:21 +0100 Subject: [PATCH 168/229] added mistral-4b v1 reranker (#1278) --- .../haystack_integrations/components/rankers/nvidia/ranker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index ec1af7ab8..67215d753 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -17,6 +17,7 @@ _DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3" _MODEL_ENDPOINT_MAP = { + "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking", "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", "nvidia/llama-3.2-nv-rerankqa-1b-v1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v1/reranking", "nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking", From f5c3aee5478b310d9fee785ba2ade196753fc78a Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 8 Jan 2025 10:57:56 +0000 Subject: [PATCH 169/229] Update the changelog --- integrations/nvidia/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index 2523907d6..a3de61119 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/nvidia-v0.1.4] - 2025-01-08 + +### 🌀 Miscellaneous + +- Feat: add nv-rerank-qa-mistral-4b:1 reranker (#1278) + ## [integrations/nvidia-v0.1.3] - 2025-01-02 ### 🚀 Features From 97a2cdaca52af584ecb3d2282b78be40dc99e039 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 9 Jan 2025 18:55:24 +0100 Subject: [PATCH 170/229] remove tests involving serialization of lambdas (#1281) --- .../anthropic/tests/test_chat_generator.py | 19 ------------------ .../tests/test_cohere_chat_generator.py | 20 ------------------- 2 files changed, 39 deletions(-) diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 36622ecd9..d46ee624d 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -81,25 +81,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") - component = AnthropicChatGenerator( - model="claude-3-5-sonnet-20240620", - streaming_callback=lambda x: x, - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, - "model": "claude-3-5-sonnet-20240620", - "streaming_callback": "tests.test_chat_generator.", - "generation_kwargs": {"max_tokens": 
10, "some_test_param": "test-params"}, - "ignore_tools_thinking_messages": True, - }, - } - def test_from_dict(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-api-key") data = { diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 4aaa2da2b..1a4774a40 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -98,26 +98,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): - monkeypatch.setenv("COHERE_API_KEY", "test-api-key") - component = CohereChatGenerator( - model="command-r", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", - "init_parameters": { - "model": "command-r", - "api_base_url": "test-base-url", - "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, - "streaming_callback": "tests.test_cohere_chat_generator.", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - def test_from_dict(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "fake-api-key") monkeypatch.setenv("CO_API_KEY", "fake-api-key") From 20c943753e8704642f19da2771b29c9a86788e4c Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 10 Jan 2025 11:01:17 +0100 Subject: [PATCH 171/229] remove more tests involving serialization of lambdas (#1285) --- .../anthropic/tests/test_generator.py | 19 --------------- .../tests/test_vertex_chat_generator.py | 24 ------------------- .../cohere/tests/test_cohere_generator.py | 21 ---------------- 3 files changed, 64 deletions(-) diff --git a/integrations/anthropic/tests/test_generator.py b/integrations/anthropic/tests/test_generator.py index 029cd3920..918b6f775 100644 --- a/integrations/anthropic/tests/test_generator.py +++ b/integrations/anthropic/tests/test_generator.py @@ -70,25 +70,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") - component = AnthropicGenerator( - model="claude-3-sonnet-20240229", - streaming_callback=lambda x: x, - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack_integrations.components.generators.anthropic.generator.AnthropicGenerator", - "init_parameters": { - "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, - "model": "claude-3-sonnet-20240229", - "streaming_callback": "tests.test_generator.", - "system_prompt": None, - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - def test_from_dict(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-api-key") data = { diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py index fefb508ac..6c3a30d89 100644 --- a/integrations/anthropic/tests/test_vertex_chat_generator.py +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -83,30 +83,6 @@ def test_to_dict_with_parameters(self): }, } - def 
test_to_dict_with_lambda_streaming_callback(self): - component = AnthropicVertexChatGenerator( - region="us-central1", - project_id="test-project-id", - model="claude-3-5-sonnet@20240620", - streaming_callback=lambda x: x, - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": ( - "haystack_integrations.components.generators." - "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" - ), - "init_parameters": { - "region": "us-central1", - "project_id": "test-project-id", - "model": "claude-3-5-sonnet@20240620", - "streaming_callback": "tests.test_vertex_chat_generator.", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - "ignore_tools_thinking_messages": True, - }, - } - def test_from_dict(self): data = { "type": ( diff --git a/integrations/cohere/tests/test_cohere_generator.py b/integrations/cohere/tests/test_cohere_generator.py index 60ee6ac93..fffe872f5 100644 --- a/integrations/cohere/tests/test_cohere_generator.py +++ b/integrations/cohere/tests/test_cohere_generator.py @@ -78,27 +78,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): - monkeypatch.setenv("COHERE_API_KEY", "test-api-key") - component = CohereGenerator( - model="command-r", - max_tokens=10, - some_test_param="test-params", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - ) - data = component.to_dict() - assert data == { - "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", - "init_parameters": { - "model": "command-r", - "streaming_callback": "tests.test_cohere_generator.", - "api_base_url": "test-base-url", - "api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True}, - "generation_kwargs": {}, - }, - } - def test_from_dict(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "fake-api-key") monkeypatch.setenv("CO_API_KEY", "fake-api-key") From 20011ec2163401d04beac111da6b93f1a8839159 Mon Sep 17 00:00:00 2001 From: Martin Barton <32176108+mabartcz@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:21:38 +0100 Subject: [PATCH 172/229] fix: PgvectorDocumentStore - use appropriate schema name if dropping index (#1277) * fix: Add schema name if dropping index in pgvector store * fix: Remove check for deletion in src * new integration test --------- Co-authored-by: anakin87 --- integrations/pgvector/README.md | 2 +- .../pgvector/document_store.py | 4 +- .../pgvector/tests/test_document_store.py | 45 +++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/integrations/pgvector/README.md b/integrations/pgvector/README.md index a2d325c54..4859762f9 100644 --- a/integrations/pgvector/README.md +++ b/integrations/pgvector/README.md @@ -22,7 +22,7 @@ pip install pgvector-haystack Ensure that you have a PostgreSQL running with the `pgvector` extension. 
For a quick setup using Docker, run: ``` -docker run -d -p 5432:5432 -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres ankane/pgvector +docker run -d -p 5432:5432 -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres pgvector/pgvector:pg17 ``` then run the tests: diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 87655a5ec..648ae88af 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -389,7 +389,9 @@ def _handle_hnsw(self): ) return - sql_drop_index = SQL("DROP INDEX IF EXISTS {index_name}").format(index_name=Identifier(self.hnsw_index_name)) + sql_drop_index = SQL("DROP INDEX IF EXISTS {schema_name}.{index_name}").format( + schema_name=Identifier(self.schema_name), index_name=Identifier(self.hnsw_index_name) + ) self._execute_sql(sql_drop_index, error_msg="Could not drop HNSW index") self._create_hnsw_index() diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index baa921137..a331a990e 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -5,6 +5,7 @@ from unittest.mock import patch import numpy as np +import psycopg import pytest from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DuplicateDocumentError @@ -259,3 +260,47 @@ def test_from_pg_to_haystack_documents(): assert haystack_docs[2].meta == {"meta_key": "meta_value"} assert haystack_docs[2].embedding == [0.7, 0.8, 0.9] assert haystack_docs[2].score is None + + +@pytest.mark.integration +def test_hnsw_index_recreation(): + def get_index_oid(document_store, schema_name, index_name): + sql_get_index_oid = """ + SELECT c.oid + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE c.relkind = 'i' + AND n.nspname = %s + AND c.relname = %s; + """ + return document_store.cursor.execute(sql_get_index_oid, (schema_name, index_name)).fetchone()[0] + + # create a new schema + connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + schema_name = "test_schema" + with psycopg.connect(connection_string, autocommit=True) as conn: + conn.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") + + # create a first document store and trigger the creation of the hnsw index + params = { + "connection_string": Secret.from_token(connection_string), + "schema_name": schema_name, + "table_name": "haystack_test_hnsw_index_recreation", + "search_strategy": "hnsw", + } + ds1 = PgvectorDocumentStore(**params) + ds1._initialize_table() + + # get the hnsw index oid + hnws_index_name = "haystack_hnsw_index" + first_oid = get_index_oid(ds1, ds1.schema_name, hnws_index_name) + + # create second document store with recreation enabled + ds2 = PgvectorDocumentStore(**params, hnsw_recreate_index_if_exists=True) + ds2._initialize_table() + + # get the index oid + second_oid = get_index_oid(ds2, ds2.schema_name, hnws_index_name) + + # verify that oids differ + assert second_oid != first_oid, "Index was not recreated (OID remained the same)" From e2a008aac4aa37becf325f51bc521ee326b6cb4e Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 10 Jan 2025 17:24:32 +0000 Subject: [PATCH 173/229] 
Update the changelog --- integrations/pgvector/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/pgvector/CHANGELOG.md b/integrations/pgvector/CHANGELOG.md index f3821f1d3..c68df2da4 100644 --- a/integrations/pgvector/CHANGELOG.md +++ b/integrations/pgvector/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/pgvector-v1.2.1] - 2025-01-10 + +### 🐛 Bug Fixes + +- PgvectorDocumentStore - use appropriate schema name if dropping index (#1277) + + ## [integrations/pgvector-v1.2.0] - 2024-11-22 ### 🚀 Features From bed73fc8e1e8db7b12e320f8bbc6e6f6dbb58df2 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 15 Jan 2025 08:33:38 +0100 Subject: [PATCH 174/229] fix: Cohere - fix chat message creation (#1289) --- .../components/generators/cohere/chat/chat_generator.py | 4 ++-- integrations/cohere/tests/test_cohere_chat_generator.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 3fae30baa..33e7c98f6 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -172,7 +172,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, response_text += event.text elif event.event_type == "stream-end": finish_response = event.response - chat_message = ChatMessage.from_assistant(content=response_text) + chat_message = ChatMessage.from_assistant(response_text) if finish_response and finish_response.meta: if finish_response.meta.billed_units: @@ -219,7 +219,7 @@ def _build_message(self, cohere_response): # TODO revisit to see if we need to handle multiple tool calls message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json()) elif cohere_response.text: - message = ChatMessage.from_assistant(content=cohere_response.text) + message = ChatMessage.from_assistant(cohere_response.text) message.meta.update( { "model": self.model, diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 1a4774a40..05a18f074 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -144,7 +144,7 @@ def test_message_to_dict(self, chat_messages): ) @pytest.mark.integration def test_live_run(self): - chat_messages = [ChatMessage.from_user(content="What's the capital of France")] + chat_messages = [ChatMessage.from_user("What's the capital of France")] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages) assert len(results["replies"]) == 1 @@ -181,7 +181,7 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() component = CohereChatGenerator(streaming_callback=callback) - results = component.run([ChatMessage.from_user(content="What's the capital of France? answer in a word")]) + results = component.run([ChatMessage.from_user("What's the capital of France? 
answer in a word")]) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] @@ -202,7 +202,7 @@ def __call__(self, chunk: StreamingChunk) -> None: ) @pytest.mark.integration def test_live_run_with_connector(self): - chat_messages = [ChatMessage.from_user(content="What's the capital of France")] + chat_messages = [ChatMessage.from_user("What's the capital of France")] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 @@ -227,7 +227,7 @@ def __call__(self, chunk: StreamingChunk) -> None: self.responses += chunk.content if chunk.content else "" callback = Callback() - chat_messages = [ChatMessage.from_user(content="What's the capital of France? answer in a word")] + chat_messages = [ChatMessage.from_user("What's the capital of France? answer in a word")] component = CohereChatGenerator(streaming_callback=callback) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) From 717648eff7f8f8dea7538421ae444817e119f752 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 15 Jan 2025 07:35:03 +0000 Subject: [PATCH 175/229] Update the changelog --- integrations/cohere/CHANGELOG.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/integrations/cohere/CHANGELOG.md b/integrations/cohere/CHANGELOG.md index 1d98408e9..1300a3efa 100644 --- a/integrations/cohere/CHANGELOG.md +++ b/integrations/cohere/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [integrations/cohere-v2.0.2] - 2025-01-15 + +### 🐛 Bug Fixes + +- Make GoogleAI Chat Generator compatible with new `ChatMessage`; small fixes to Cohere tests (#1253) +- Cohere - fix chat message creation (#1289) + +### 🌀 Miscellaneous + +- Test: remove tests involving serialization of lambdas (#1281) +- Test: remove more tests involving serialization of lambdas (#1285) + ## [integrations/cohere-v2.0.1] - 2024-12-09 ### ⚙️ CI From 6641c2d1d67616cedac84cc82f9643626f0c759d Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 15 Jan 2025 15:35:19 +0100 Subject: [PATCH 176/229] chore: Mistral - pin haystack-ai>=2.9.0 and simplify test (#1293) --- integrations/mistral/pyproject.toml | 2 +- integrations/mistral/tests/test_mistral_chat_generator.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/integrations/mistral/pyproject.toml b/integrations/mistral/pyproject.toml index 06d02c0aa..b2694b729 100644 --- a/integrations/mistral/pyproject.toml +++ b/integrations/mistral/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai"] +dependencies = ["haystack-ai>=2.9.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/mistral#readme" diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py index be3dce497..185497591 100644 --- a/integrations/mistral/tests/test_mistral_chat_generator.py +++ b/integrations/mistral/tests/test_mistral_chat_generator.py @@ -198,12 +198,7 @@ def test_check_abnormal_completions(self, caplog): ] for m in messages: - try: - # Haystack >= 2.9.0 - component._check_finish_reason(m.meta) - except AttributeError: - # Haystack < 2.9.0 - component._check_finish_reason(m) + component._check_finish_reason(m.meta) # 
check truncation warning message_template = ( From 0f58003bab00dc4ffe410ac258e5a3a88b533255 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 15 Jan 2025 17:14:42 +0100 Subject: [PATCH 177/229] pin haystack-ai>=2.9.0 and simplify (#1292) --- integrations/langfuse/pyproject.toml | 2 +- .../tracing/langfuse/tracer.py | 18 ++---------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/integrations/langfuse/pyproject.toml b/integrations/langfuse/pyproject.toml index 44397b572..92ebc7b8f 100644 --- a/integrations/langfuse/pyproject.toml +++ b/integrations/langfuse/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.1.0", "langfuse"] +dependencies = ["haystack-ai>=2.9.0", "langfuse"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/langfuse#readme" diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index b6bc96860..1b7187f30 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -6,16 +6,12 @@ from haystack import logging from haystack.dataclasses import ChatMessage -from haystack.lazy_imports import LazyImport from haystack.tracing import Span, Tracer from haystack.tracing import tracer as proxy_tracer from haystack.tracing import utils as tracing_utils import langfuse -with LazyImport("") as openai_utils_import: - from haystack.components.generators.openai_utils import _convert_message_to_openai_format - logger = logging.getLogger(__name__) HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH" @@ -86,24 +82,14 @@ def set_content_tag(self, key: str, value: Any) -> None: return if key.endswith(".input"): if "messages" in value: - if openai_utils_import.is_successful(): - # Haystack < 2.9.0 - messages = [_convert_message_to_openai_format(m) for m in value["messages"]] - else: - # Haystack >= 2.9.0 - messages = [m.to_openai_dict_format() for m in value["messages"]] + messages = [m.to_openai_dict_format() for m in value["messages"]] self._span.update(input=messages) else: self._span.update(input=value) elif key.endswith(".output"): if "replies" in value: if all(isinstance(r, ChatMessage) for r in value["replies"]): - if openai_utils_import.is_successful(): - # Haystack < 2.9.0 - replies = [_convert_message_to_openai_format(m) for m in value["replies"]] - else: - # Haystack >= 2.9.0 - replies = [m.to_openai_dict_format() for m in value["replies"]] + replies = [m.to_openai_dict_format() for m in value["replies"]] else: replies = value["replies"] self._span.update(output=replies) From 9dfa9bc096db2619e7bf7349e429b521862af1b1 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 15 Jan 2025 16:17:06 +0000 Subject: [PATCH 178/229] Update the changelog --- integrations/langfuse/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index a3148496d..9f857966b 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/langfuse-v0.6.3] - 2025-01-15 + +### 🌀 Miscellaneous + +- Chore: Langfuse - pin `haystack-ai>=2.9.0` and simplify message conversion (#1292) 
+ ## [integrations/langfuse-v0.6.2] - 2025-01-02 ### 🚀 Features From e42032177a448a94529be9662f76f138322acfa9 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 16 Jan 2025 11:56:03 +0100 Subject: [PATCH 179/229] feat: Ollama - add support for tools (#1294) * progress * Ollama: add support for tools * make generator tests run with llama * PR feedback + simplification --- .github/workflows/ollama.yml | 2 +- integrations/ollama/pyproject.toml | 1 + .../generators/ollama/chat/chat_generator.py | 143 +++++-- .../ollama/tests/test_chat_generator.py | 367 ++++++++++++++++-- integrations/ollama/tests/test_generator.py | 4 +- 5 files changed, 436 insertions(+), 81 deletions(-) diff --git a/.github/workflows/ollama.yml b/.github/workflows/ollama.yml index 8be6bf263..8b53f4a55 100644 --- a/.github/workflows/ollama.yml +++ b/.github/workflows/ollama.yml @@ -21,7 +21,7 @@ concurrency: env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" - LLM_FOR_TESTS: "orca-mini" + LLM_FOR_TESTS: "llama3.2:3b" EMBEDDER_FOR_TESTS: "nomic-embed-text" jobs: diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index c9fc22f3d..65895e636 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", + "jsonschema", # needed for Tool ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index daae552e5..c2112d3d6 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,12 +1,68 @@ from typing import Any, Callable, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from ollama import ChatResponse, Client +def _convert_chatmessage_to_ollama_format(message: ChatMessage) -> Dict[str, Any]: + """ + Convert a ChatMessage to the format expected by Ollama Chat API. + """ + text_contents = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not text_contents and not tool_calls and not tool_call_results: + msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + raise ValueError(msg) + elif len(text_contents) + len(tool_call_results) > 1: + msg = "A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`." 
+ raise ValueError(msg) + + ollama_msg: Dict[str, Any] = {"role": message._role.value} + + if tool_call_results: + # Ollama does not provide a way to communicate errors in tool invocations, so we ignore the error field + ollama_msg["content"] = tool_call_results[0].result + return ollama_msg + + if text_contents: + ollama_msg["content"] = text_contents[0] + if tool_calls: + # Ollama does not support tool call id, so we ignore it + ollama_msg["tool_calls"] = [ + {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}} for tc in tool_calls + ] + return ollama_msg + + +def _convert_ollama_response_to_chatmessage(ollama_response: "ChatResponse") -> ChatMessage: + """ + Converts the non-streaming response from the Ollama API to a ChatMessage with assistant role. + """ + response_dict = ollama_response.model_dump() + + ollama_message = response_dict["message"] + + text = ollama_message["content"] + + tool_calls = [] + if ollama_tool_calls := ollama_message.get("tool_calls"): + for ollama_tc in ollama_tool_calls: + tool_calls.append( + ToolCall(tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"]) + ) + + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls) + + message.meta.update({key: value for key, value in response_dict.items() if key != "message"}) + return message + + @component class OllamaChatGenerator: """ @@ -40,6 +96,7 @@ def __init__( timeout: int = 120, keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, ): """ :param model: @@ -52,9 +109,6 @@ def __init__( [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). :param timeout: The number of seconds before throwing a timeout error from the Ollama API. - :param streaming_callback: - A callback function that is called when a new token is received from the stream. - The callback function accepts StreamingChunk as an argument. :param keep_alive: The option that controls how long the model will stay loaded into memory following the request. If not set, it will use the default value from the Ollama (5 minutes). @@ -63,14 +117,24 @@ def __init__( - a number in seconds (such as 3600) - any negative number which will keep the model loaded in memory (e.g. -1 or "-1m") - '0' which will unload the model immediately after generating a response. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param tools: + A list of tools for which the model can prepare calls. + Not all models support tools. For a list of models compatible with tools, see the + [models page](https://ollama.com/search?c=tools). """ + _check_duplicate_tool_names(tools) + self.timeout = timeout self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model self.keep_alive = keep_alive self.streaming_callback = streaming_callback + self.tools = tools self._client = Client(host=self.url, timeout=self.timeout) @@ -82,6 +146,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. 
""" callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, model=self.model, @@ -90,6 +155,7 @@ def to_dict(self) -> Dict[str, Any]: generation_kwargs=self.generation_kwargs, timeout=self.timeout, streaming_callback=callback_name, + tools=serialized_tools, ) @classmethod @@ -102,34 +168,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator": :returns: Deserialized component. """ + deserialize_tools_inplace(data["init_parameters"], key="tools") + init_params = data.get("init_parameters", {}) - serialized_callback_handler = init_params.get("streaming_callback") - if serialized_callback_handler: + + if serialized_callback_handler := init_params.get("streaming_callback"): data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) - def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: - return {"role": message.role.value, "content": message.text} - - def _build_message_from_ollama_response(self, ollama_response: ChatResponse) -> ChatMessage: - """ - Converts the non-streaming response from the Ollama API to a ChatMessage. - """ - response_dict = ollama_response.model_dump() - message = ChatMessage.from_assistant(response_dict["message"]["content"]) - message.meta.update({key: value for key, value in response_dict.items() if key != "message"}) - return message - - def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: - """ - Converts a list of chunks response required Haystack format. - """ - - replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))] - meta = {key: value for key, value in chunks[0].meta.items() if key != "message"} - - return {"replies": replies, "meta": [meta]} - def _build_chunk(self, chunk_response: Any) -> StreamingChunk: """ Converts the response from the Ollama API to a StreamingChunk. @@ -143,23 +189,28 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk: chunk_message = StreamingChunk(content, meta) return chunk_message - def _handle_streaming_response(self, response) -> List[StreamingChunk]: + def _handle_streaming_response(self, response) -> Dict[str, List[Any]]: """ - Handles Streaming response cases + Handles streaming response and converts it to Haystack format """ chunks: List[StreamingChunk] = [] for chunk in response: - chunk_delta: StreamingChunk = self._build_chunk(chunk) + chunk_delta = self._build_chunk(chunk) chunks.append(chunk_delta) if self.streaming_callback is not None: self.streaming_callback(chunk_delta) - return chunks + + replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))] + meta = {key: value for key, value in chunks[0].meta.items() if key != "message"} + + return {"replies": replies, "meta": [meta]} @component.output_types(replies=List[ChatMessage]) def run( self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, ): """ Runs an Ollama Model on a given chat history. @@ -170,21 +221,35 @@ def run( Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, etc. See the [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). - :param streaming_callback: - A callback function that will be called with each response chunk in streaming mode. 
+ :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. :returns: A dictionary with the following keys: - `replies`: The responses from the model """ generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} stream = self.streaming_callback is not None - messages = [self._message_to_dict(message) for message in messages] + tools = tools or self.tools + _check_duplicate_tool_names(tools) + + if stream and tools: + msg = "Ollama does not support tools and streaming at the same time. Please choose one." + raise ValueError(msg) + + ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] if tools else None + + ollama_messages = [_convert_chatmessage_to_ollama_format(msg) for msg in messages] response = self._client.chat( - model=self.model, messages=messages, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs + model=self.model, + messages=ollama_messages, + tools=ollama_tools, + stream=stream, + keep_alive=self.keep_alive, + options=generation_kwargs, ) if stream: - chunks: List[StreamingChunk] = self._handle_streaming_response(response) - return self._convert_to_streaming_response(chunks) + return self._handle_streaming_response(response) - return {"replies": [self._build_message_from_ollama_response(response)]} + return {"replies": [_convert_ollama_response_to_chatmessage(response)]} diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index b3df0fbf1..cb357027a 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -1,22 +1,158 @@ -from typing import List -from unittest.mock import Mock +import json +from unittest.mock import Mock, patch import pytest from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage +from haystack.dataclasses import ( + ChatMessage, + ChatRole, + StreamingChunk, + TextContent, + ToolCall, +) +from haystack.tools import Tool from ollama._types import ChatResponse, ResponseError -from haystack_integrations.components.generators.ollama import OllamaChatGenerator +from haystack_integrations.components.generators.ollama.chat.chat_generator import ( + OllamaChatGenerator, + _convert_chatmessage_to_ollama_format, + _convert_ollama_response_to_chatmessage, +) @pytest.fixture -def chat_messages() -> List[ChatMessage]: - return [ - ChatMessage.from_user("Tell me about why Super Mario is the greatest superhero"), - ChatMessage.from_assistant( - "Super Mario has prevented Bowser from destroying the world", {"something": "something"} - ), - ] +def tools(): + tool_parameters = { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] + + +def test_convert_chatmessage_to_ollama_format(): + message = ChatMessage.from_system("You are good assistant") + assert _convert_chatmessage_to_ollama_format(message) == { + "role": "system", + "content": "You are good assistant", + } + + message = ChatMessage.from_user("I have a question") + assert _convert_chatmessage_to_ollama_format(message) == { + "role": "user", + "content": "I have a question", + } + + message = ChatMessage.from_assistant(text="I have an answer", 
meta={"finish_reason": "stop"}) + assert _convert_chatmessage_to_ollama_format(message) == { + "role": "assistant", + "content": "I have an answer", + } + + message = ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] + ) + assert _convert_chatmessage_to_ollama_format(message) == { + "role": "assistant", + "tool_calls": [ + { + "type": "function", + "function": {"name": "weather", "arguments": {"city": "Paris"}}, + } + ], + } + + tool_result = json.dumps({"weather": "sunny", "temperature": "25"}) + message = ChatMessage.from_tool( + tool_result=tool_result, + origin=ToolCall(tool_name="weather", arguments={"city": "Paris"}), + ) + assert _convert_chatmessage_to_ollama_format(message) == { + "role": "tool", + "content": tool_result, + } + + +def test_convert_chatmessage_to_ollama_invalid(): + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + _convert_chatmessage_to_ollama_format(message) + + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[ + TextContent(text="I have an answer"), + TextContent(text="I have another answer"), + ], + ) + with pytest.raises(ValueError): + _convert_chatmessage_to_ollama_format(message) + + +def test_convert_ollama_response_to_chatmessage(): + model = "some_model" + + ollama_response = ChatResponse( + model=model, + created_at="2023-12-12T14:13:43.416799Z", + message={"role": "assistant", "content": "Hello! How are you today?"}, + done=True, + total_duration=5191566416, + load_duration=2154458, + prompt_eval_count=26, + prompt_eval_duration=383809000, + eval_count=298, + eval_duration=4799921000, + ) + + observed = _convert_ollama_response_to_chatmessage(ollama_response) + + assert observed.role == "assistant" + assert observed.text == "Hello! How are you today?" 
+ + +def test_convert_ollama_response_to_chatmessage_with_tools(): + model = "some_model" + + ollama_response = ChatResponse( + model=model, + created_at="2023-12-12T14:13:43.416799Z", + message={ + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "Paris, FR"}, + } + } + ], + }, + done=True, + total_duration=5191566416, + load_duration=2154458, + prompt_eval_count=26, + prompt_eval_duration=383809000, + eval_count=298, + eval_duration=4799921000, + ) + + observed = _convert_ollama_response_to_chatmessage(ollama_response) + + assert observed.role == "assistant" + assert observed.text == "" + assert observed.tool_call == ToolCall( + tool_name="get_current_weather", + arguments={"format": "celsius", "location": "Paris, FR"}, + ) class TestOllamaChatGenerator: @@ -26,15 +162,19 @@ def test_init_default(self): assert component.url == "http://localhost:11434" assert component.generation_kwargs == {} assert component.timeout == 120 + assert component.streaming_callback is None + assert component.tools is None assert component.keep_alive is None - def test_init(self): + def test_init(self, tools): component = OllamaChatGenerator( model="llama2", url="http://my-custom-endpoint:11434", generation_kwargs={"temperature": 0.5}, - keep_alive="10m", timeout=5, + keep_alive="10m", + streaming_callback=print_streaming_chunk, + tools=tools, ) assert component.model == "llama2" @@ -42,13 +182,29 @@ def test_init(self): assert component.generation_kwargs == {"temperature": 0.5} assert component.timeout == 5 assert component.keep_alive == "10m" + assert component.streaming_callback is print_streaming_chunk + assert component.tools == tools + + def test_init_fail_with_duplicate_tool_names(self, tools): + + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + OllamaChatGenerator(tools=duplicate_tools) def test_to_dict(self): + tool = Tool( + name="name", + description="description", + parameters={"x": {"type": "string"}}, + function=print, + ) + component = OllamaChatGenerator( model="llama2", streaming_callback=print_streaming_chunk, url="custom_url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + tools=[tool], keep_alive="5m", ) data = component.to_dict() @@ -57,14 +213,39 @@ def test_to_dict(self): "init_parameters": { "timeout": 120, "model": "llama2", - "keep_alive": "5m", "url": "custom_url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "keep_alive": "5m", + "generation_kwargs": { + "max_tokens": 10, + "some_test_param": "test-params", + }, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + }, + ], }, } def test_from_dict(self): + tool = Tool( + name="name", + description="description", + parameters={"x": {"type": "string"}}, + function=print, + ) + data = { "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", "init_parameters": { @@ -73,23 +254,50 @@ def test_from_dict(self): "url": "custom_url", "keep_alive": "5m", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "generation_kwargs": { + "max_tokens": 10, + 
"some_test_param": "test-params", + }, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + }, + ], }, } component = OllamaChatGenerator.from_dict(data) assert component.model == "llama2" assert component.streaming_callback is print_streaming_chunk assert component.url == "custom_url" - assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.keep_alive == "5m" + assert component.generation_kwargs == { + "max_tokens": 10, + "some_test_param": "test-params", + } + assert component.timeout == 120 + assert component.tools == [tool] - def test_build_message_from_ollama_response(self): - model = "some_model" + @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client") + def test_run(self, mock_client): + generator = OllamaChatGenerator() - ollama_response = ChatResponse( - model=model, + mock_response = ChatResponse( + model="llama3.2", created_at="2023-12-12T14:13:43.416799Z", - message={"role": "assistant", "content": "Hello! How are you today?"}, + message={ + "role": "assistant", + "content": "Fine. How can I help you today?", + }, done=True, total_duration=5191566416, load_duration=2154458, @@ -99,14 +307,80 @@ def test_build_message_from_ollama_response(self): eval_duration=4799921000, ) - observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response) + mock_client_instance = mock_client.return_value + mock_client_instance.chat.return_value = mock_response + + result = generator.run(messages=[ChatMessage.from_user("Hello! How are you today?")]) + + mock_client_instance.chat.assert_called_once_with( + model="orca-mini", + messages=[{"role": "user", "content": "Hello! How are you today?"}], + stream=False, + tools=None, + options={}, + keep_alive=None, + ) + + assert "replies" in result + assert len(result["replies"]) == 1 + assert result["replies"][0].text == "Fine. How can I help you today?" 
+ assert result["replies"][0].role == "assistant" + + @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client") + def test_run_streaming(self, mock_client): + streaming_callback_called = False + + def streaming_callback(_: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + generator = OllamaChatGenerator(streaming_callback=streaming_callback) + + mock_response = iter( + [ + ChatResponse( + model="llama3.2", + created_at="2023-12-12T14:13:43.416799Z", + message={"role": "assistant", "content": "first chunk "}, + done=False, + ), + ChatResponse( + model="llama3.2", + created_at="2023-12-12T14:13:43.416799Z", + message={"role": "assistant", "content": "second chunk"}, + done=True, + total_duration=4883583458, + load_duration=1334875, + prompt_eval_count=26, + prompt_eval_duration=342546000, + eval_count=282, + eval_duration=4535599000, + ), + ] + ) + + mock_client_instance = mock_client.return_value + mock_client_instance.chat.return_value = mock_response + + result = generator.run(messages=[ChatMessage.from_user("irrelevant")]) + + assert streaming_callback_called + + assert "replies" in result + assert len(result["replies"]) == 1 + assert result["replies"][0].text == "first chunk second chunk" + assert result["replies"][0].role == "assistant" + + def test_run_fail_with_tools_and_streaming(self, tools): + component = OllamaChatGenerator(tools=tools, streaming_callback=print_streaming_chunk) - assert observed.role == "assistant" - assert observed.text == "Hello! How are you today?" + with pytest.raises(ValueError): + message = ChatMessage.from_user("irrelevant") + component.run([message]) @pytest.mark.integration - def test_run(self): - chat_generator = OllamaChatGenerator() + def test_live_run(self): + chat_generator = OllamaChatGenerator(model="llama3.2:3b") user_questions_and_assistant_answers = [ ("What's the capital of France?", "Paris"), @@ -125,45 +399,60 @@ def test_run(self): @pytest.mark.integration def test_run_with_chat_history(self): - chat_generator = OllamaChatGenerator() + chat_generator = OllamaChatGenerator(model="llama3.2:3b") - chat_history = [ + chat_messages = [ ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), ChatMessage.from_user("And what is the second largest?"), ] - response = chat_generator.run(chat_history) + response = chat_generator.run(chat_messages) assert isinstance(response, dict) assert isinstance(response["replies"], list) - assert "Manchester" in response["replies"][-1].text or "Glasgow" in response["replies"][-1].text + + assert any(city in response["replies"][-1].text for city in ["Manchester", "Birmingham", "Glasgow"]) @pytest.mark.integration def test_run_model_unavailable(self): - component = OllamaChatGenerator(model="Alistair_and_Stefano_are_great") + component = OllamaChatGenerator(model="unknown_model") with pytest.raises(ResponseError): - message = ChatMessage.from_user( - "Based on your infinite wisdom, can you tell me why Alistair and Stefano are so great?" 
- ) + message = ChatMessage.from_user("irrelevant") component.run([message]) @pytest.mark.integration def test_run_with_streaming(self): streaming_callback = Mock() - chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback) + chat_generator = OllamaChatGenerator(model="llama3.2:3b", streaming_callback=streaming_callback) - chat_history = [ + chat_messages = [ ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), ChatMessage.from_user("And what is the second largest?"), ] - response = chat_generator.run(chat_history) + response = chat_generator.run(chat_messages) streaming_callback.assert_called() assert isinstance(response, dict) assert isinstance(response["replies"], list) - assert "Manchester" in response["replies"][-1].text or "Glasgow" in response["replies"][-1].text + assert any(city in response["replies"][-1].text for city in ["Manchester", "Birmingham", "Glasgow"]) + + @pytest.mark.integration + def test_run_with_tools(self, tools): + chat_generator = OllamaChatGenerator(model="llama3.2:3b", tools=tools) + + message = ChatMessage.from_user("What is the weather in Paris?") + response = chat_generator.run([message]) + + assert len(response["replies"]) == 1 + message = response["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index b02370234..e394f57f0 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -19,7 +19,7 @@ def test_run_capital_cities(self): ("What is the capital of Ghana?", "Accra"), ] - component = OllamaGenerator() + component = OllamaGenerator(model="llama3.2:3b") for prompt, answer in prompts_and_answers: results = component.run(prompt=prompt) @@ -147,7 +147,7 @@ def __call__(self, chunk): return chunk callback = Callback() - component = OllamaGenerator(streaming_callback=callback) + component = OllamaGenerator(model="llama3.2:3b", streaming_callback=callback) results = component.run(prompt="What's the capital of Netherlands?") assert len(results["replies"]) == 1 From b48c90bfaf48c412b3dbce2c73575684c301f1ee Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 16 Jan 2025 10:57:55 +0000 Subject: [PATCH 180/229] Update the changelog --- integrations/ollama/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 7aeb70021..e4e8f3602 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/ollama-v2.2.0] - 2025-01-16 + +### 🚀 Features + +- Ollama - add support for tools (#1294) + + ## [integrations/ollama-v2.1.2] - 2024-12-18 ### 🐛 Bug Fixes From 7ceaeafade1d435ffd5a4eeaf5162c6677062a27 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 16 Jan 2025 14:34:58 +0100 Subject: [PATCH 181/229] vertex: handle function role removal (#1296) --- .../components/generators/google_vertex/chat/gemini.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py 
b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 845e24f5f..516116321 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -215,7 +215,7 @@ def _message_to_part(self, message: ChatMessage) -> Part: return p elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): return Part.from_text(message.text) - elif message.is_from(ChatRole.FUNCTION): + elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): return Part.from_function_response(name=message.name, response=message.text) elif message.is_from(ChatRole.USER): return self._convert_part(message.text) @@ -227,14 +227,15 @@ def _message_to_content(self, message: ChatMessage) -> Content: part.function_call.args[k] = v elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): part = Part.from_text(message.text) - elif message.is_from(ChatRole.FUNCTION): + elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): part = Part.from_function_response(name=message.name, response=message.text) elif message.is_from(ChatRole.USER): part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" + + role = "model" if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM) else "user" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) From 232eb3d1081a2ab7e8b34520451ed87ffd97b373 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 16 Jan 2025 14:35:20 +0100 Subject: [PATCH 182/229] feat: Add Secret handling in OpenSearchDocumentStore (#1288) * Add Secret handling in OpenSearchDocumentStore * only serialize auth secrets when values are resolvable * Update integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py Co-authored-by: tstadel <60758086+tstadel@users.noreply.github.com> * Fixes * Revert accidental commit * Special list of Secrets handling only, keep everything else as it was before * Small improvement * More simplifications --------- Co-authored-by: tstadel <60758086+tstadel@users.noreply.github.com> --- .../opensearch/document_store.py | 45 +++++++++++--- .../opensearch/tests/test_document_store.py | 60 +++++++++++++++++++ 2 files changed, 96 insertions(+), 9 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 6cb5295f0..7deaad285 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -9,6 +9,7 @@ from haystack.dataclasses import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.utils.auth import Secret from opensearchpy import OpenSearch from opensearchpy.helpers import bulk @@ -45,7 +46,10 @@ def __init__( mappings: Optional[Dict[str, Any]] = None, settings: Optional[Dict[str, Any]] = DEFAULT_SETTINGS, create_index: bool = 
True, - http_auth: Any = None, + http_auth: Any = ( + Secret.from_env_var("OPENSEARCH_USERNAME", strict=False), # noqa: B008 + Secret.from_env_var("OPENSEARCH_PASSWORD", strict=False), # noqa: B008 + ), use_ssl: Optional[bool] = None, verify_certs: Optional[bool] = None, timeout: Optional[int] = None, @@ -79,6 +83,7 @@ def __init__( - a tuple of (username, password) - a list of [username, password] - a string of "username:password" + If not provided, will read values from OPENSEARCH_USERNAME and OPENSEARCH_PASSWORD environment variables. For AWS authentication with `Urllib3HttpConnection` pass an instance of `AWSAuth`. Defaults to None :param use_ssl: Whether to use SSL. Defaults to None @@ -97,6 +102,17 @@ def __init__( self._mappings = mappings or self._get_default_mappings() self._settings = settings self._create_index = create_index + self._http_auth_are_secrets = False + + # Handle authentication + if isinstance(http_auth, (tuple, list)) and len(http_auth) == 2: # noqa: PLR2004 + username, password = http_auth + if isinstance(username, Secret) and isinstance(password, Secret): + self._http_auth_are_secrets = True + username_val = username.resolve_value() + password_val = password.resolve_value() + http_auth = [username_val, password_val] if username_val and password_val else None + self._http_auth = http_auth self._use_ssl = use_ssl self._verify_certs = verify_certs @@ -174,15 +190,24 @@ def create_index( self.client.indices.create(index=index, body={"mappings": mappings, "settings": settings}) def to_dict(self) -> Dict[str, Any]: - # This is not the best solution to serialise this class but is the fastest to implement. - # Not all kwargs types can be serialised to text so this can fail. We must serialise each - # type explicitly to handle this properly. """ Serializes the component to a dictionary. :returns: Dictionary with serialized data. """ + # Handle http_auth serialization + if isinstance(self._http_auth, list) and self._http_auth_are_secrets: + # Recreate the Secret objects for serialization + http_auth = [ + Secret.from_env_var("OPENSEARCH_USERNAME", strict=False).to_dict(), + Secret.from_env_var("OPENSEARCH_PASSWORD", strict=False).to_dict(), + ] + elif isinstance(self._http_auth, AWSAuth): + http_auth = self._http_auth.to_dict() + else: + http_auth = self._http_auth + return default_to_dict( self, hosts=self._hosts, @@ -194,7 +219,7 @@ def to_dict(self) -> Dict[str, Any]: settings=self._settings, create_index=self._create_index, return_embedding=self._return_embedding, - http_auth=self._http_auth.to_dict() if isinstance(self._http_auth, AWSAuth) else self._http_auth, + http_auth=http_auth, use_ssl=self._use_ssl, verify_certs=self._verify_certs, timeout=self._timeout, @@ -208,14 +233,16 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchDocumentStore": :param data: Dictionary to deserialize from. - :returns: Deserialized component. 
""" - if http_auth := data.get("init_parameters", {}).get("http_auth"): + init_params = data.get("init_parameters", {}) + if http_auth := init_params.get("http_auth"): if isinstance(http_auth, dict): - data["init_parameters"]["http_auth"] = AWSAuth.from_dict(http_auth) - + init_params["http_auth"] = AWSAuth.from_dict(http_auth) + elif isinstance(http_auth, (tuple, list)): + are_secrets = all(isinstance(item, dict) and "type" in item for item in http_auth) + init_params["http_auth"] = [Secret.from_dict(item) for item in http_auth] if are_secrets else http_auth return default_from_dict(cls, data) def count_documents(self) -> int: diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 043f59891..82c21e6fe 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -263,6 +263,66 @@ def test_to_dict_aws_auth(self, _mock_opensearch_client, monkeypatch: pytest.Mon }, } + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_with_env_var_secrets(self, _mock_opensearch_client, monkeypatch): + """Test the default initialization using environment variables""" + monkeypatch.setenv("OPENSEARCH_USERNAME", "user") + monkeypatch.setenv("OPENSEARCH_PASSWORD", "pass") + + document_store = OpenSearchDocumentStore(hosts="testhost") + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ["user", "pass"] + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_with_missing_env_vars(self, _mock_opensearch_client): + """Test that auth is None when environment variables are missing""" + document_store = OpenSearchDocumentStore(hosts="testhost") + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] is None + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_to_dict_with_env_var_secrets(self, _mock_opensearch_client, monkeypatch): + """Test serialization with environment variables""" + monkeypatch.setenv("OPENSEARCH_USERNAME", "user") + monkeypatch.setenv("OPENSEARCH_PASSWORD", "pass") + + document_store = OpenSearchDocumentStore(hosts="testhost") + serialized = document_store.to_dict() + + assert "http_auth" in serialized["init_parameters"] + auth = serialized["init_parameters"]["http_auth"] + assert isinstance(auth, list) + assert len(auth) == 2 + # Check that we have two Secret dictionaries with correct env vars + assert auth[0]["type"] == "env_var" + assert auth[0]["env_vars"] == ["OPENSEARCH_USERNAME"] + assert auth[1]["type"] == "env_var" + assert auth[1]["env_vars"] == ["OPENSEARCH_PASSWORD"] + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_from_dict_with_env_var_secrets(self, _mock_opensearch_client, monkeypatch): + """Test deserialization with environment variables""" + # Set environment variables so the secrets resolve properly + monkeypatch.setenv("OPENSEARCH_USERNAME", "user") + monkeypatch.setenv("OPENSEARCH_PASSWORD", "pass") + + data = { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "testhost", + "http_auth": [ + {"type": "env_var", "env_vars": ["OPENSEARCH_USERNAME"], "strict": False}, + {"type": "env_var", "env_vars": 
["OPENSEARCH_PASSWORD"], "strict": False}, + ], + }, + } + document_store = OpenSearchDocumentStore.from_dict(data) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ["user", "pass"] + @pytest.mark.integration class TestDocumentStore(DocumentStoreBaseTests): From 6d836b01be0773e71833b30779c867fc5349fa28 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 16 Jan 2025 13:37:15 +0000 Subject: [PATCH 183/229] Update the changelog --- integrations/google_vertex/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index 7cf633ebd..a544507a1 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [unreleased] + +### 🌀 Miscellaneous + +- Handle function role removal (#1296) + ## [integrations/google_vertex-v4.0.1] - 2024-12-19 ### 🐛 Bug Fixes From 203182b22a04e295c1a413a728d93fe7dd777cce Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 16 Jan 2025 13:52:48 +0000 Subject: [PATCH 184/229] Update the changelog --- integrations/opensearch/CHANGELOG.md | 55 +++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/integrations/opensearch/CHANGELOG.md b/integrations/opensearch/CHANGELOG.md index fef7b4bc3..a9870cd71 100644 --- a/integrations/opensearch/CHANGELOG.md +++ b/integrations/opensearch/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [unreleased] + +### 🚀 Features + +- Add Secret handling in OpenSearchDocumentStore (#1288) + + ## [integrations/opensearch-v1.2.0] - 2024-12-12 ### 🧹 Chores @@ -85,6 +92,8 @@ - Fix order of API docs (#447) +This PR will also push the docs to Readme + ### 📚 Documentation - Update category slug (#442) @@ -100,8 +109,32 @@ - Generate API docs (#324) - Make tests show coverage (#566) + +* make tests show coverage + +* rm duplicate coverage definition - Refactor tests (#574) + +* first refactorings + +* separate unit tests in pgvector + +* small change to weaviate + +* fix format + +* usefixtures when possible - Fix opensearch errors bulk write (#594) + +* fix(opensearch): bulk error without create key + +* Add test + +* Comment why + +--------- + +Co-authored-by: Corentin Meyer - Remove references to Python 3.7 (#601) - [Elasticsearch] fix: Filters not working with metadata that contain a space or capitalization (#639) - Chore: add license classifiers (#680) @@ -120,14 +153,18 @@ ### 🌀 Miscellaneous -- Fix opensearch test badge (#97) +- Update README.md (#97) - Move package under haystack_integrations/* (#212) +* move package under haystack_integrations/* + +* ignore types + ## [integrations/opensearch-v0.1.1] - 2023-12-05 ### 🐛 Bug Fixes -- Document Stores: fix protocol import (#77) +- Fix import and increase version (#77) ## [integrations/opensearch-v0.1.0] - 2023-12-04 @@ -139,6 +176,20 @@ - Remove Document Store decorator (#76) +* remove decorator + +* Update integrations/elasticsearch/src/elasticsearch_haystack/__about__.py + +Co-authored-by: Massimiliano Pippi + +* Update integrations/opensearch/src/opensearch_haystack/__about__.py + +Co-authored-by: Massimiliano Pippi + +--------- + +Co-authored-by: Massimiliano Pippi + ## [integrations/opensearch-v0.0.2] - 2023-11-30 ### 🚀 Features From ee543c1c1f84104b3706d991ee298f404aadfa16 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 16 Jan 2025 15:26:16 +0100 Subject: [PATCH 185/229] chore: llama.cpp - gently 
handle the removal of ChatMessage.from_function (#1298) --- integrations/llama_cpp/pyproject.toml | 2 +- .../llama_cpp/chat/chat_generator.py | 6 +-- .../llama_cpp/tests/test_chat_generator.py | 39 ++++++++++--------- 3 files changed, 22 insertions(+), 25 deletions(-) diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index a33434e1b..2f0989e3b 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "llama-cpp-python>=0.2.87"] +dependencies = ["haystack-ai>=2.9.0", "llama-cpp-python>=0.2.87"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama_cpp#readme" diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py index d2150f61f..b3e395bdd 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py @@ -138,11 +138,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, name = tool_calls[0]["function"]["name"] reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta) - if name: - if hasattr(reply, "_name"): - reply._name = name # new ChatMessage - elif hasattr(reply, "name"): - reply.name = name # legacy ChatMessage + reply._name = name or None replies.append(reply) return {"replies": replies} diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py index 87639f684..7f5125554 100644 --- a/integrations/llama_cpp/tests/test_chat_generator.py +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -40,12 +40,12 @@ def test_convert_message_to_llamacpp_format(): message = ChatMessage.from_user("I have a question") assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"} - message = ChatMessage.from_function("Function call", "function_name") - converted_message = _convert_message_to_llamacpp_format(message) - - assert converted_message["role"] in ("function", "tool") - assert converted_message["name"] == "function_name" - assert converted_message["content"] == "Function call" + if hasattr(ChatMessage, "from_function"): + message = ChatMessage.from_function("Function call", "function_name") + converted_message = _convert_message_to_llamacpp_format(message) + assert converted_message["role"] in ("function", "tool") + assert converted_message["name"] == "function_name" + assert converted_message["content"] == "Function call" class TestLlamaCppChatGenerator: @@ -420,19 +420,20 @@ def test_function_call_and_execute(self, generator): assert "tool_calls" in first_reply.meta tool_calls = first_reply.meta["tool_calls"] - for tool_call in tool_calls: - function_name = tool_call["function"]["name"] - function_args = json.loads(tool_call["function"]["arguments"]) - assert function_name in available_functions - function_response = available_functions[function_name](**function_args) - function_message = ChatMessage.from_function(function_response, function_name) - messages.append(function_message) 
- - second_response = generator.run(messages=messages) - assert "replies" in second_response - assert len(second_response["replies"]) > 0 - assert any("San Francisco" in reply.text for reply in second_response["replies"]) - assert any("72" in reply.text for reply in second_response["replies"]) + if hasattr(ChatMessage, "from_function"): + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + assert function_name in available_functions + function_response = available_functions[function_name](**function_args) + function_message = ChatMessage.from_function(function_response, function_name) + messages.append(function_message) + + second_response = generator.run(messages=messages) + assert "replies" in second_response + assert len(second_response["replies"]) > 0 + assert any("San Francisco" in reply.text for reply in second_response["replies"]) + assert any("72" in reply.text for reply in second_response["replies"]) class TestLlamaCppChatGeneratorChatML: From dc0d1530e047abbe35346e00839d074aefb91b94 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 16 Jan 2025 14:27:23 +0000 Subject: [PATCH 186/229] Update the changelog --- integrations/llama_cpp/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/llama_cpp/CHANGELOG.md b/integrations/llama_cpp/CHANGELOG.md index 930486a0d..09d79d234 100644 --- a/integrations/llama_cpp/CHANGELOG.md +++ b/integrations/llama_cpp/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/llama_cpp-v0.4.4] - 2025-01-16 + +### 🧹 Chores + +- Llama.cpp - gently handle the removal of ChatMessage.from_function (#1298) + + ## [integrations/llama_cpp-v0.4.3] - 2024-12-19 ### 🐛 Bug Fixes From fecf0e24e55a35a99f2161008b24602a12a8ec32 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 16 Jan 2025 15:36:45 +0100 Subject: [PATCH 187/229] manually fix changelog (#1299) --- integrations/opensearch/CHANGELOG.md | 53 +++------------------------- 1 file changed, 4 insertions(+), 49 deletions(-) diff --git a/integrations/opensearch/CHANGELOG.md b/integrations/opensearch/CHANGELOG.md index a9870cd71..55bb6e962 100644 --- a/integrations/opensearch/CHANGELOG.md +++ b/integrations/opensearch/CHANGELOG.md @@ -1,12 +1,11 @@ # Changelog -## [unreleased] +## [integrations/opensearch-v1.3.0] - 2025-01-16 ### 🚀 Features - Add Secret handling in OpenSearchDocumentStore (#1288) - ## [integrations/opensearch-v1.2.0] - 2024-12-12 ### 🧹 Chores @@ -92,8 +91,6 @@ - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) @@ -109,32 +106,8 @@ This PR will also push the docs to Readme - Generate API docs (#324) - Make tests show coverage (#566) - -* make tests show coverage - -* rm duplicate coverage definition - Refactor tests (#574) - -* first refactorings - -* separate unit tests in pgvector - -* small change to weaviate - -* fix format - -* usefixtures when possible - Fix opensearch errors bulk write (#594) - -* fix(opensearch): bulk error without create key - -* Add test - -* Comment why - ---------- - -Co-authored-by: Corentin Meyer - Remove references to Python 3.7 (#601) - [Elasticsearch] fix: Filters not working with metadata that contain a space or capitalization (#639) - Chore: add license classifiers (#680) @@ -153,18 +126,14 @@ Co-authored-by: Corentin Meyer ### 🌀 Miscellaneous -- Update README.md (#97) +- Fix opensearch test badge (#97) - Move package under haystack_integrations/* (#212) -* 
move package under haystack_integrations/* - -* ignore types - ## [integrations/opensearch-v0.1.1] - 2023-12-05 ### 🐛 Bug Fixes -- Fix import and increase version (#77) +- Document Stores: fix protocol import (#77) ## [integrations/opensearch-v0.1.0] - 2023-12-04 @@ -176,20 +145,6 @@ Co-authored-by: Corentin Meyer - Remove Document Store decorator (#76) -* remove decorator - -* Update integrations/elasticsearch/src/elasticsearch_haystack/__about__.py - -Co-authored-by: Massimiliano Pippi - -* Update integrations/opensearch/src/opensearch_haystack/__about__.py - -Co-authored-by: Massimiliano Pippi - ---------- - -Co-authored-by: Massimiliano Pippi - ## [integrations/opensearch-v0.0.2] - 2023-11-30 ### 🚀 Features @@ -206,4 +161,4 @@ Co-authored-by: Massimiliano Pippi - [OpenSearch] add document store, BM25Retriever and EmbeddingRetriever (#68) - + \ No newline at end of file From 501f31cf37761bcaa0eff3f1fcc3d4e9c9806cdc Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 16 Jan 2025 17:55:34 +0100 Subject: [PATCH 188/229] chore: google-ai - gently handle the removal of function role (#1297) --- integrations/google_ai/pyproject.toml | 4 +- .../generators/google_ai/chat/gemini.py | 8 ++-- .../tests/generators/chat/test_chat_gemini.py | 40 ++++++++++--------- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index 9a4a070e7..4da1db297 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "google-generativeai>=0.3.1"] +dependencies = ["haystack-ai>=2.9.0", "google-generativeai>=0.3.1"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_ai_haystack#readme" @@ -56,7 +56,7 @@ cov = ["test-cov", "cov-report"] cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11"] +python = ["3.9", "3.10", "3.11"] [tool.hatch.envs.lint] installer = "uv" diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 69f168a6b..2addaca7a 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -242,12 +242,12 @@ def _message_to_part(self, message: ChatMessage) -> Part: p = Part() p.text = message.text return p - elif message.is_from(ChatRole.FUNCTION): + elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): p = Part() p.function_response.name = message.name p.function_response.response = message.text return p - elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): + elif message.is_from(ChatRole.TOOL): p = Part() p.function_response.name = message.tool_call_result.origin.tool_name p.function_response.response = message.tool_call_result.result @@ -265,13 +265,13 @@ def _message_to_content(self, message: ChatMessage) -> Content: elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): part = Part() part.text = message.text - elif 
message.is_from(ChatRole.FUNCTION): + elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): part = Part() part.function_response.name = message.name part.function_response.response = message.text elif message.is_from(ChatRole.USER): part = self._convert_part(message.text) - elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL): + elif message.is_from(ChatRole.TOOL): part = Part() part.function_response.name = message.tool_call_result.origin.tool_name part.function_response.response = message.tool_call_result.result diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 0683bf21a..ce12d4a4d 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -227,16 +227,17 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} weather = get_current_weather(**json.loads(chat_message.text)) - messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] - response = gemini_chat.run(messages=messages) - assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + if hasattr(ChatMessage, "from_function"): + messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) - # check the second response is not a function call - chat_message = response["replies"][0] - assert "function_call" not in chat_message.meta - assert isinstance(chat_message.text, str) + # check the second response is not a function call + chat_message = response["replies"][0] + assert "function_call" not in chat_message.meta + assert isinstance(chat_message.text, str) @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") @@ -273,16 +274,17 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} weather = get_current_weather(**json.loads(chat_message.text)) - messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] - response = gemini_chat.run(messages=messages) - assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) - - # check the second response is not a function call - chat_message = response["replies"][0] - assert "function_call" not in chat_message.meta - assert isinstance(chat_message.text, str) + if hasattr(ChatMessage, "from_function"): + messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + + # check the second response is not a function call + chat_message = response["replies"][0] + assert "function_call" not in chat_message.meta + assert isinstance(chat_message.text, str) @pytest.mark.skipif(not 
os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") From 72989f4461b863d8be6051dbed5fdefdd6d870ed Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 16 Jan 2025 16:57:16 +0000 Subject: [PATCH 189/229] Update the changelog --- integrations/google_ai/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 71cdf4e74..5004682ee 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/google_ai-v4.1.0] - 2025-01-16 + +### 🧹 Chores + +- Google-ai - gently handle the removal of function role (#1297) + + ## [integrations/google_ai-v4.0.1] - 2024-12-19 ### 🐛 Bug Fixes From c33e81cef274c0a6a9da2ea755cb2477e19006b9 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 17 Jan 2025 11:10:50 +0100 Subject: [PATCH 190/229] chore: inherit from `FilterDocumentsTestWithDataframe` in Document Stores (#1290) * use FilterDocumentsTestWithDataframe testing class * lint * weaviate fix * another fix * missing import --- .../astra/tests/test_document_store.py | 4 +- .../azure_ai_search/tests/conftest.py | 1 - .../tests/test_document_store.py | 3 +- .../chroma/tests/test_document_store.py | 12 --- .../tests/test_document_store.py | 4 +- .../tests/test_document_store.py | 4 +- .../opensearch/tests/test_document_store.py | 4 +- integrations/pgvector/tests/test_filters.py | 4 +- integrations/pinecone/tests/test_filters.py | 3 +- integrations/qdrant/tests/test_filters.py | 4 +- .../weaviate/tests/test_document_store.py | 82 ++++++------------- 11 files changed, 39 insertions(+), 86 deletions(-) diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index ef00b6b25..fbb233b91 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -9,7 +9,7 @@ from haystack import Document from haystack.document_stores.errors import MissingDocumentError from haystack.document_stores.types import DuplicatePolicy -from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.testing.document_store import DocumentStoreBaseTests, FilterDocumentsTestWithDataframe from haystack_integrations.document_stores.astra import AstraDocumentStore @@ -47,7 +47,7 @@ def test_to_dict(mock_auth): # noqa os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" ) @pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set") -class TestDocumentStore(DocumentStoreBaseTests): +class TestDocumentStore(DocumentStoreBaseTests, FilterDocumentsTestWithDataframe): """ Common test cases will be provided by `DocumentStoreBaseTests` but you can add more to this class. 
diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py index 02742031c..2861c71d6 100644 --- a/integrations/azure_ai_search/tests/conftest.py +++ b/integrations/azure_ai_search/tests/conftest.py @@ -10,7 +10,6 @@ from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore - # This is the approximate time in seconds it takes for the documents to be available in Azure Search index SLEEP_TIME_IN_SECONDS = 10 MAX_WAIT_TIME_FOR_INDEX_DELETION = 10 diff --git a/integrations/azure_ai_search/tests/test_document_store.py b/integrations/azure_ai_search/tests/test_document_store.py index 1bcd967c6..ced1347ae 100644 --- a/integrations/azure_ai_search/tests/test_document_store.py +++ b/integrations/azure_ai_search/tests/test_document_store.py @@ -14,6 +14,7 @@ CountDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest, + FilterDocumentsTestWithDataframe, WriteDocumentsTest, ) from haystack.utils.auth import EnvVarSecret, Secret @@ -155,7 +156,7 @@ def _random_embeddings(n): ], indirect=True, ) -class TestFilters(FilterDocumentsTest): +class TestFilters(FilterDocumentsTest, FilterDocumentsTestWithDataframe): # Overriding to change "date" to compatible ISO 8601 format # and remove incompatible fields (dataframes) for Azure search index diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index ed815251e..889467fc0 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -406,18 +406,6 @@ def test_nested_logical_filters(self, document_store: ChromaDocumentStore, filte ], ) - @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") - def test_comparison_equal_with_dataframe( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): - pass - - @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") - def test_comparison_not_equal_with_dataframe( - self, document_store: ChromaDocumentStore, filterable_docs: List[Document] - ): - pass - @pytest.mark.skip(reason="Chroma does not support comparison with null values") def test_comparison_equal_with_none(self, document_store, filterable_docs): pass diff --git a/integrations/elasticsearch/tests/test_document_store.py b/integrations/elasticsearch/tests/test_document_store.py index d636ff027..32f388a7c 100644 --- a/integrations/elasticsearch/tests/test_document_store.py +++ b/integrations/elasticsearch/tests/test_document_store.py @@ -11,7 +11,7 @@ from haystack.dataclasses.document import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy -from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.testing.document_store import DocumentStoreBaseTests, FilterDocumentsTestWithDataframe from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore @@ -70,7 +70,7 @@ def test_from_dict(_mock_elasticsearch_client): @pytest.mark.integration -class TestDocumentStore(DocumentStoreBaseTests): +class TestDocumentStore(DocumentStoreBaseTests, FilterDocumentsTestWithDataframe): """ Common test cases will be provided by `DocumentStoreBaseTests` but you can add more to this class. 
diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 6c0ac191e..99817ef90 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -9,7 +9,7 @@ from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy -from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.testing.document_store import DocumentStoreBaseTests, FilterDocumentsTestWithDataframe from haystack.utils import Secret from pandas import DataFrame from pymongo import MongoClient @@ -35,7 +35,7 @@ def test_init_is_lazy(_mock_client): reason="No MongoDB Atlas connection string provided", ) @pytest.mark.integration -class TestDocumentStore(DocumentStoreBaseTests): +class TestDocumentStore(DocumentStoreBaseTests, FilterDocumentsTestWithDataframe): @pytest.fixture def document_store(self): database_name = "haystack_integration_test" diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 82c21e6fe..41bb99cf6 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -9,7 +9,7 @@ from haystack.dataclasses.document import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy -from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.testing.document_store import DocumentStoreBaseTests, FilterDocumentsTestWithDataframe from haystack.utils.auth import Secret from opensearchpy.exceptions import RequestError @@ -325,7 +325,7 @@ def test_from_dict_with_env_var_secrets(self, _mock_opensearch_client, monkeypat @pytest.mark.integration -class TestDocumentStore(DocumentStoreBaseTests): +class TestDocumentStore(DocumentStoreBaseTests, FilterDocumentsTestWithDataframe): """ Common test cases will be provided by `DocumentStoreBaseTests` but you can add more to this class. diff --git a/integrations/pgvector/tests/test_filters.py b/integrations/pgvector/tests/test_filters.py index 08e693471..de5646d2f 100644 --- a/integrations/pgvector/tests/test_filters.py +++ b/integrations/pgvector/tests/test_filters.py @@ -2,7 +2,7 @@ import pytest from haystack.dataclasses.document import Document -from haystack.testing.document_store import FilterDocumentsTest +from haystack.testing.document_store import FilterDocumentsTest, FilterDocumentsTestWithDataframe from pandas import DataFrame from psycopg.sql import SQL from psycopg.types.json import Jsonb @@ -17,7 +17,7 @@ @pytest.mark.integration -class TestFilters(FilterDocumentsTest): +class TestFilters(FilterDocumentsTest, FilterDocumentsTestWithDataframe): def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ This overrides the default assert_documents_are_equal from FilterDocumentsTest. 
diff --git a/integrations/pinecone/tests/test_filters.py b/integrations/pinecone/tests/test_filters.py index 40c9cdb10..012418446 100644 --- a/integrations/pinecone/tests/test_filters.py +++ b/integrations/pinecone/tests/test_filters.py @@ -5,12 +5,13 @@ from haystack.dataclasses.document import Document from haystack.testing.document_store import ( FilterDocumentsTest, + FilterDocumentsTestWithDataframe, ) @pytest.mark.integration @pytest.mark.skipif("PINECONE_API_KEY" not in os.environ, reason="PINECONE_API_KEY not set") -class TestFilters(FilterDocumentsTest): +class TestFilters(FilterDocumentsTest, FilterDocumentsTestWithDataframe): def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): for doc in received: # Pinecone seems to convert integers to floats (undocumented behavior) diff --git a/integrations/qdrant/tests/test_filters.py b/integrations/qdrant/tests/test_filters.py index 1afd2b7f3..78098a78b 100644 --- a/integrations/qdrant/tests/test_filters.py +++ b/integrations/qdrant/tests/test_filters.py @@ -2,14 +2,14 @@ import pytest from haystack import Document -from haystack.testing.document_store import FilterDocumentsTest +from haystack.testing.document_store import FilterDocumentsTest, FilterDocumentsTestWithDataframe from haystack.utils.filters import FilterError from qdrant_client.http import models from haystack_integrations.document_stores.qdrant import QdrantDocumentStore -class TestQdrantStoreBaseTests(FilterDocumentsTest): +class TestQdrantStoreBaseTests(FilterDocumentsTest, FilterDocumentsTestWithDataframe): @pytest.fixture def document_store(self) -> QdrantDocumentStore: return QdrantDocumentStore( diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 00af322e4..cf9927a3d 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -4,7 +4,6 @@ import base64 import os -import random from typing import List from unittest.mock import MagicMock, patch @@ -14,18 +13,17 @@ from haystack.dataclasses.document import Document from haystack.document_stores.errors import DocumentStoreError from haystack.testing.document_store import ( - TEST_EMBEDDING_1, - TEST_EMBEDDING_2, CountDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest, + FilterDocumentsTestWithDataframe, WriteDocumentsTest, + create_filterable_docs, ) from haystack.utils.auth import Secret from numpy import array as np_array from numpy import array_equal as np_array_equal from numpy import float32 as np_float32 -from pandas import DataFrame from weaviate.collections.classes.data import DataObject from weaviate.config import AdditionalConfig, ConnectionConfig, Proxies, Timeout from weaviate.embedded import ( @@ -50,7 +48,9 @@ def test_init_is_lazy(_mock_client): @pytest.mark.integration -class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest): +class TestWeaviateDocumentStore( + CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest, FilterDocumentsTestWithDataframe +): @pytest.fixture def document_store(self, request) -> WeaviateDocumentStore: # Use a different index for each test so we can run them in parallel @@ -78,60 +78,24 @@ def filterable_docs(self) -> List[Document]: Weaviate forces RFC 3339 date strings. The original fixture uses ISO 8601 date strings. 
""" - documents = [] - for i in range(3): - documents.append( - Document( - content=f"A Foo Document {i}", - meta={ - "name": f"name_{i}", - "page": "100", - "chapter": "intro", - "number": 2, - "date": "1969-07-21T20:17:40Z", - }, - embedding=[random.random() for _ in range(768)], # noqa: S311 - ) - ) - documents.append( - Document( - content=f"A Bar Document {i}", - meta={ - "name": f"name_{i}", - "page": "123", - "chapter": "abstract", - "number": -2, - "date": "1972-12-11T19:54:58Z", - }, - embedding=[random.random() for _ in range(768)], # noqa: S311 - ) - ) - documents.append( - Document( - content=f"A Foobar Document {i}", - meta={ - "name": f"name_{i}", - "page": "90", - "chapter": "conclusion", - "number": -10, - "date": "1989-11-09T17:53:00Z", - }, - embedding=[random.random() for _ in range(768)], # noqa: S311 - ) - ) - documents.append( - Document( - content=f"Document {i} without embedding", - meta={"name": f"name_{i}", "no_embedding": True, "chapter": "conclusion"}, - ) - ) - documents.append(Document(dataframe=DataFrame([i]), meta={"name": f"table_doc_{i}"})) - documents.append( - Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) - ) - documents.append( - Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) - ) + documents = create_filterable_docs(include_dataframe_docs=False) + for i in range(len(documents)): + if date := documents[i].meta.get("date"): + documents[i].meta["date"] = f"{date}Z" + return documents + + @pytest.fixture + def filterable_docs_with_dataframe(self) -> List[Document]: + """ + This fixture has been copied from haystack/testing/document_store.py and modified to + use a different date format. + Weaviate forces RFC 3339 date strings. + The original fixture uses ISO 8601 date strings. 
+ """ + documents = create_filterable_docs(include_dataframe_docs=True) + for i in range(len(documents)): + if date := documents[i].meta.get("date"): + documents[i].meta["date"] = f"{date}Z" return documents def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): From 7e598d8237881a4902fce918abbfef1682d13169 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 17 Jan 2025 14:12:55 +0100 Subject: [PATCH 191/229] feat: Add LangfuseConnector secure key management and serialization (#1287) * LangfuseConnector: add secret_key and public_key init params * Update tests * Linting * Add serde test * Lint * PR feedback * PR feedback --- .../connectors/langfuse/langfuse_connector.py | 50 ++++++++++- integrations/langfuse/tests/test_tracing.py | 84 +++++++++++++++++-- 2 files changed, 123 insertions(+), 11 deletions(-) diff --git a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py index ff0a7c6ed..1762a92de 100644 --- a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py +++ b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional -from haystack import component, logging, tracing +from haystack import component, default_from_dict, default_to_dict, logging, tracing +from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.tracing.langfuse import LangfuseTracer from langfuse import Langfuse @@ -94,7 +95,13 @@ async def shutdown_event(): """ - def __init__(self, name: str, public: bool = False): + def __init__( + self, + name: str, + public: bool = False, + public_key: Optional[Secret] = Secret.from_env_var("LANGFUSE_PUBLIC_KEY"), # noqa: B008 + secret_key: Optional[Secret] = Secret.from_env_var("LANGFUSE_SECRET_KEY"), # noqa: B008 + ): """ Initialize the LangfuseConnector component. @@ -103,9 +110,21 @@ def __init__(self, name: str, public: bool = False): :param public: Whether the tracing data should be public or private. If set to `True`, the tracing data will be publicly accessible to anyone with the tracing URL. If set to `False`, the tracing data will be private and only accessible to the Langfuse account owner. The default is `False`. + :param public_key: The Langfuse public key. Defaults to reading from LANGFUSE_PUBLIC_KEY environment variable. + :param secret_key: The Langfuse secret key. Defaults to reading from LANGFUSE_SECRET_KEY environment variable. """ self.name = name - self.tracer = LangfuseTracer(tracer=Langfuse(), name=name, public=public) + self.public = public + self.secret_key = secret_key + self.public_key = public_key + self.tracer = LangfuseTracer( + tracer=Langfuse( + secret_key=secret_key.resolve_value() if secret_key else None, + public_key=public_key.resolve_value() if public_key else None, + ), + name=name, + public=public, + ) tracing.enable_tracing(self.tracer) @component.output_types(name=str, trace_url=str) @@ -126,3 +145,28 @@ def run(self, invocation_context: Optional[Dict[str, Any]] = None): invocation_context=invocation_context, ) return {"name": self.name, "trace_url": self.tracer.get_trace_url()} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: The serialized component as a dictionary. 
+ """ + return default_to_dict( + self, + name=self.name, + public=self.public, + secret_key=self.secret_key.to_dict() if self.secret_key else None, + public_key=self.public_key.to_dict() if self.public_key else None, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "LangfuseConnector": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: The deserialized component instance. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["secret_key", "public_key"]) + return default_from_dict(cls, data) diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 75c1b7a13..d4815fc9c 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -1,5 +1,4 @@ import os -import random import time from urllib.parse import urlparse @@ -9,6 +8,7 @@ from haystack.components.builders import ChatPromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage +from haystack.utils import Secret from requests.auth import HTTPBasicAuth from haystack_integrations.components.connectors.langfuse import LangfuseConnector @@ -19,6 +19,36 @@ os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" +@pytest.fixture +def pipeline_with_env_vars(llm_class, expected_trace): + """Pipeline factory using environment variables for Langfuse authentication""" + pipe = Pipeline() + pipe.add_component("tracer", LangfuseConnector(name=f"Chat example - {expected_trace}", public=True)) + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", llm_class()) + pipe.connect("prompt_builder.prompt", "llm.messages") + return pipe + + +@pytest.fixture +def pipeline_with_secrets(llm_class, expected_trace): + """Pipeline factory using Secret objects for Langfuse authentication""" + pipe = Pipeline() + pipe.add_component( + "tracer", + LangfuseConnector( + name=f"Chat example - {expected_trace}", + public=True, + secret_key=Secret.from_env_var("LANGFUSE_SECRET_KEY"), + public_key=Secret.from_env_var("LANGFUSE_PUBLIC_KEY"), + ), + ) + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", llm_class()) + pipe.connect("prompt_builder.prompt", "llm.messages") + return pipe + + @pytest.mark.integration @pytest.mark.parametrize( "llm_class, env_var, expected_trace", @@ -28,16 +58,12 @@ (CohereChatGenerator, "COHERE_API_KEY", "Cohere"), ], ) -def test_tracing_integration(llm_class, env_var, expected_trace): +@pytest.mark.parametrize("pipeline_fixture", ["pipeline_with_env_vars", "pipeline_with_secrets"]) +def test_tracing_integration(llm_class, env_var, expected_trace, pipeline_fixture, request): if not all([os.environ.get("LANGFUSE_SECRET_KEY"), os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get(env_var)]): pytest.skip(f"Missing required environment variables: LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, or {env_var}") - pipe = Pipeline() - pipe.add_component("tracer", LangfuseConnector(name=f"Chat example - {expected_trace}", public=True)) - pipe.add_component("prompt_builder", ChatPromptBuilder()) - pipe.add_component("llm", llm_class()) - pipe.connect("prompt_builder.prompt", "llm.messages") - + pipe = request.getfixturevalue(pipeline_fixture) messages = [ ChatMessage.from_system("Always respond in German even if some input data is in other languages."), ChatMessage.from_user("Tell me about {{location}}"), @@ -77,3 +103,45 
@@ def test_tracing_integration(llm_class, env_var, expected_trace): # check if the trace contains the expected user_id assert "user_42" in str(res.content) break + + +def test_pipeline_serialization(monkeypatch): + """Test that a pipeline with secrets can be properly serialized and deserialized""" + + # Set test env vars + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "secret") + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "public") + monkeypatch.setenv("OPENAI_API_KEY", "openai_api_key") + + # Create pipeline with OpenAI LLM + pipe = Pipeline() + pipe.add_component( + "tracer", + LangfuseConnector( + name="Chat example - OpenAI", + public=True, + secret_key=Secret.from_env_var("LANGFUSE_SECRET_KEY"), + public_key=Secret.from_env_var("LANGFUSE_PUBLIC_KEY"), + ), + ) + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", OpenAIChatGenerator()) + pipe.connect("prompt_builder.prompt", "llm.messages") + + # Serialize + serialized = pipe.to_dict() + + # Check serialized secrets + tracer_params = serialized["components"]["tracer"]["init_parameters"] + assert isinstance(tracer_params["secret_key"], dict) + assert tracer_params["secret_key"]["type"] == "env_var" + assert tracer_params["secret_key"]["env_vars"] == ["LANGFUSE_SECRET_KEY"] + assert isinstance(tracer_params["public_key"], dict) + assert tracer_params["public_key"]["type"] == "env_var" + assert tracer_params["public_key"]["env_vars"] == ["LANGFUSE_PUBLIC_KEY"] + + # Deserialize + new_pipe = Pipeline.from_dict(serialized) + + # Verify pipeline is the same + assert new_pipe == pipe From 4e2241abad2719868efa719aef6fcff58d7a86ca Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 17 Jan 2025 13:15:34 +0000 Subject: [PATCH 192/229] Update the changelog --- integrations/langfuse/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 9f857966b..3eb8ddba6 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/langfuse-v0.6.4] - 2025-01-17 + +### 🚀 Features + +- Add LangfuseConnector secure key management and serialization (#1287) + + ## [integrations/langfuse-v0.6.3] - 2025-01-15 ### 🌀 Miscellaneous From 42ac5ca62fa2bbac28c7437beee8b856cb26ee9c Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 17 Jan 2025 15:04:53 +0100 Subject: [PATCH 193/229] Anthropic tools + refactoring (#1300) --- .../anthropic/example/prompt_caching.py | 2 +- integrations/anthropic/pyproject.toml | 3 +- .../anthropic/chat/chat_generator.py | 526 ++++++---- .../anthropic/chat/vertex_chat_generator.py | 12 +- .../anthropic/tests/test_chat_generator.py | 935 +++++++++++++++--- .../tests/test_vertex_chat_generator.py | 2 + 6 files changed, 1142 insertions(+), 338 deletions(-) diff --git a/integrations/anthropic/example/prompt_caching.py b/integrations/anthropic/example/prompt_caching.py index d8cc0f0e8..4f2ee58fb 100644 --- a/integrations/anthropic/example/prompt_caching.py +++ b/integrations/anthropic/example/prompt_caching.py @@ -59,7 +59,7 @@ def stream_callback(chunk: StreamingChunk) -> None: ) if ENABLE_PROMPT_CACHING: - system_message.meta["cache_control"] = {"type": "ephemeral"} + system_message._meta["cache_control"] = {"type": "ephemeral"} questions = [ "What's this paper about?", diff --git a/integrations/anthropic/pyproject.toml b/integrations/anthropic/pyproject.toml index 21e23fbb4..401a6defd 100644 --- a/integrations/anthropic/pyproject.toml +++ 
b/integrations/anthropic/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "anthropic"] +dependencies = ["haystack-ai>=2.9.0", "anthropic"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/anthropic#readme" @@ -48,6 +48,7 @@ dependencies = [ "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", + "jsonschema", # needed for Tool ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 56a740146..752bfcb92 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -1,63 +1,152 @@ import json -from typing import Any, Callable, ClassVar, Dict, List, Optional, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall, ToolCallResult +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from anthropic import Anthropic, Stream -from anthropic.types import ( - ContentBlockDeltaEvent, - Message, - MessageDeltaEvent, - MessageStartEvent, - MessageStreamEvent, - TextBlock, - TextDelta, - ToolUseBlock, -) logger = logging.getLogger(__name__) +def _update_anthropic_message_with_tool_call_results( + tool_call_results: List[ToolCallResult], anthropic_msg: Dict[str, Any] +) -> None: + """ + Update an Anthropic message with tool call results. + + :param tool_call_results: The list of ToolCallResults to update the message with. + :param anthropic_msg: The Anthropic message to update. + """ + if "content" not in anthropic_msg: + anthropic_msg["content"] = [] + + for tool_call_result in tool_call_results: + if tool_call_result.origin.id is None: + msg = "`ToolCall` must have a non-null `id` attribute to be used with Anthropic." + raise ValueError(msg) + anthropic_msg["content"].append( + { + "type": "tool_result", + "tool_use_id": tool_call_result.origin.id, + "content": [{"type": "text", "text": tool_call_result.result}], + "is_error": tool_call_result.error, + } + ) + + +def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[Dict[str, Any]]: + """ + Convert a list of tool calls to the format expected by Anthropic Chat API. + + :param tool_calls: The list of ToolCalls to convert. + :return: A list of dictionaries in the format expected by Anthropic API. + """ + anthropic_tool_calls = [] + for tc in tool_calls: + if tc.id is None: + msg = "`ToolCall` must have a non-null `id` attribute to be used with Anthropic." 
+ raise ValueError(msg) + anthropic_tool_calls.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.tool_name, + "input": tc.arguments, + } + ) + return anthropic_tool_calls + + +def _convert_messages_to_anthropic_format( + messages: List[ChatMessage], +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Convert a list of messages to the format expected by Anthropic Chat API. + + :param messages: The list of ChatMessages to convert. + :return: A tuple of two lists: + - A list of system message dictionaries in the format expected by Anthropic API. + - A list of non-system message dictionaries in the format expected by Anthropic API. + """ + + anthropic_system_messages = [] + anthropic_non_system_messages = [] + + i = 0 + while i < len(messages): + message = messages[i] + + # allow passing cache_control + cache_control = {"cache_control": message.meta.get("cache_control")} if "cache_control" in message.meta else {} + + # system messages have special format requirements for Anthropic API + # they can have only type and text fields, and they need to be passed separately + # to the Anthropic API endpoint + if message.is_from(ChatRole.SYSTEM): + anthropic_system_messages.append({"type": "text", "text": message.text, **cache_control}) + i += 1 + continue + + anthropic_msg: Dict[str, Any] = {"role": message._role.value, "content": [], **cache_control} + + if message.texts and message.texts[0]: + anthropic_msg["content"].append({"type": "text", "text": message.texts[0]}) + if message.tool_calls: + anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format(message.tool_calls) + + if message.tool_call_results: + results = message.tool_call_results.copy() + # Handle consecutive tool call results + while (i + 1) < len(messages) and messages[i + 1].tool_call_results: + i += 1 + results.extend(messages[i].tool_call_results) + + _update_anthropic_message_with_tool_call_results(results, anthropic_msg) + anthropic_msg["role"] = "user" + + if not anthropic_msg["content"]: + msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + raise ValueError(msg) + + anthropic_non_system_messages.append(anthropic_msg) + i += 1 + + return anthropic_system_messages, anthropic_non_system_messages + + @component class AnthropicChatGenerator: """ - Enables text generation using Anthropic state-of-the-art Claude 3 family of large language models (LLMs) through - the Anthropic messaging API. + Completes chats using Anthropic's large language models (LLMs). - It supports models like `claude-3-5-sonnet`, `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku`, - accessed through the [`/v1/messages`](https://docs.anthropic.com/en/api/messages) API endpoint. + It uses [ChatMessage](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage) + format in input and output. - Users can pass any text generation parameters valid for the Anthropic messaging API directly to this component - via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs` parameter in the `run` method. + You can customize how the text is generated by passing parameters to the + Anthropic API. Use the `**generation_kwargs` argument when you initialize + the component or when you run it. Any parameter that works with + `anthropic.Message.create` will work here too. - For more details on the parameters supported by the Anthropic API, refer to the - Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages). 
+ For details on Anthropic API parameters, see + [Anthropic documentation](https://docs.anthropic.com/en/api/messages). + Usage example: ```python - from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator - from haystack.dataclasses import ChatMessage - - messages = [ChatMessage.from_user("What's Natural Language Processing?")] - client = AnthropicChatGenerator(model="claude-3-5-sonnet-20240620") - response = client.run(messages) - print(response) - - >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that - >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing - >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and - >> communicate in natural languages like English, Spanish, or Chinese.', role=, - >> name=None, meta={'model': 'claude-3-5-sonnet-20240620', 'index': 0, 'finish_reason': 'end_turn', - >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]} - ``` - - For more details on supported models and their capabilities, refer to the Anthropic - [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude). - - Note: We only support text input/output modalities, and - image [modality](https://docs.anthropic.com/en/docs/build-with-claude/vision) is not supported in - this version of AnthropicChatGenerator. + from haystack_experimental.components.generators.anthropic import AnthropicChatGenerator + from haystack_experimental.dataclasses import ChatMessage + + generator = AnthropicChatGenerator(model="claude-3-5-sonnet-20240620", + generation_kwargs={ + "max_tokens": 1000, + "temperature": 0.7, + }) + + messages = [ChatMessage.from_system("You are a helpful, respectful and honest assistant"), + ChatMessage.from_user("What's Natural Language Processing?")] + print(generator.run(messages=messages)) """ # The parameters that can be passed to the Anthropic API https://docs.anthropic.com/claude/reference/messages_post @@ -81,6 +170,7 @@ def __init__( streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ignore_tools_thinking_messages: bool = True, + tools: Optional[List[Tool]] = None, ): """ Creates an instance of AnthropicChatGenerator. @@ -107,13 +197,18 @@ def __init__( `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) for more details. + :param tools: A list of Tool objects that the model can use. Each tool should have a unique name. + """ + _check_duplicate_tool_names(tools) + self.api_key = api_key self.model = model self.generation_kwargs = generation_kwargs or {} self.streaming_callback = streaming_callback self.client = Anthropic(api_key=self.api_key.resolve_value()) self.ignore_tools_thinking_messages = ignore_tools_thinking_messages + self.tools = tools def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -129,6 +224,7 @@ def to_dict(self) -> Dict[str, Any]: The serialized component as a dictionary. 
""" callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, model=self.model, @@ -136,6 +232,7 @@ def to_dict(self) -> Dict[str, Any]: generation_kwargs=self.generation_kwargs, api_key=self.api_key.to_dict(), ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, + tools=serialized_tools, ) @classmethod @@ -148,180 +245,130 @@ def from_dict(cls, data: Dict[str, Any]) -> "AnthropicChatGenerator": The deserialized component instance. """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) - @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + @staticmethod + def _get_openai_compatible_usage(response_dict: dict) -> dict: """ - Invoke the text generation inference based on the provided messages and generation parameters. - - :param messages: A list of ChatMessage instances representing the input messages. - :param generation_kwargs: Additional keyword arguments for text generation. These parameters will - potentially override the parameters passed in the `__init__` method. - For more details on the parameters supported by the Anthropic API, refer to the - Anthropic [documentation](https://www.anthropic.com/python-library). - - :returns: - - `replies`: A list of ChatMessage instances representing the generated responses. + Converts Anthropic usage metadata to OpenAI compatible format. """ + usage = response_dict.get("usage", {}) + if usage: + if "input_tokens" in usage: + usage["prompt_tokens"] = usage.pop("input_tokens") + if "output_tokens" in usage: + usage["completion_tokens"] = usage.pop("output_tokens") - # update generation kwargs by merging with the generation kwargs passed to the run method - generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - filtered_generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS} - disallowed_params = set(generation_kwargs) - set(self.ALLOWED_PARAMS) - if disallowed_params: - logger.warning( - f"Model parameters {disallowed_params} are not allowed and will be ignored. " - f"Allowed parameters are {self.ALLOWED_PARAMS}." 
- ) - system_messages: List[ChatMessage] = [msg for msg in messages if msg.is_from(ChatRole.SYSTEM)] - non_system_messages: List[ChatMessage] = [msg for msg in messages if not msg.is_from(ChatRole.SYSTEM)] - system_messages_formatted: List[Dict[str, Any]] = ( - self._convert_to_anthropic_format(system_messages) if system_messages else [] - ) - messages_formatted: List[Dict[str, Any]] = ( - self._convert_to_anthropic_format(non_system_messages) if non_system_messages else [] - ) + return usage - extra_headers = filtered_generation_kwargs.get("extra_headers", {}) - prompt_caching_on = "anthropic-beta" in extra_headers and "prompt-caching" in extra_headers["anthropic-beta"] - has_cached_messages = any("cache_control" in m for m in system_messages_formatted) or any( - "cache_control" in m for m in messages_formatted - ) - if has_cached_messages and not prompt_caching_on: - # this avoids Anthropic errors when prompt caching is not enabled - # but user requested individual messages to be cached - logger.warn( - "Prompt caching is not enabled but you requested individual messages to be cached. " - "Messages will be sent to the API without prompt caching." - ) - system_messages_formatted = list(map(self._remove_cache_control, system_messages_formatted)) - messages_formatted = list(map(self._remove_cache_control, messages_formatted)) - - response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create( - max_tokens=filtered_generation_kwargs.pop("max_tokens", 512), - system=system_messages_formatted or filtered_generation_kwargs.pop("system", ""), - model=self.model, - messages=messages_formatted, - stream=self.streaming_callback is not None, - **filtered_generation_kwargs, - ) - - completions: List[ChatMessage] = [] - # if streaming is enabled, the response is a Stream[MessageStreamEvent] - if isinstance(response, Stream): - chunks: List[StreamingChunk] = [] - stream_event, delta, start_event = None, None, None - for stream_event in response: - if isinstance(stream_event, MessageStartEvent): - # capture start message to count input tokens - start_event = stream_event - if isinstance(stream_event, ContentBlockDeltaEvent): - chunk_delta: StreamingChunk = self._build_chunk(stream_event.delta) - chunks.append(chunk_delta) - if self.streaming_callback: - self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta - if isinstance(stream_event, MessageDeltaEvent): - # capture stop reason and stop sequence - delta = stream_event - completions = [self._connect_chunks(chunks, start_event, delta)] - - # if streaming is disabled, the response is an Anthropic Message - elif isinstance(response, Message): - has_tools_msgs = any(isinstance(content_block, ToolUseBlock) for content_block in response.content) - if has_tools_msgs and self.ignore_tools_thinking_messages: - response.content = [block for block in response.content if isinstance(block, ToolUseBlock)] - completions = [self._build_message(content_block, response) for content_block in response.content] - - # rename the meta key to be inline with OpenAI meta output keys - for response in completions: - if response.meta is not None and "usage" in response.meta: - response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens") - response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens") - - return {"replies": completions} - - def _build_message(self, content_block: Union[TextBlock, ToolUseBlock], message: Message) -> ChatMessage: + def 
_convert_chat_completion_to_chat_message( + self, anthropic_response: Any, ignore_tools_thinking_messages: bool + ) -> ChatMessage: """ - Converts the non-streaming Anthropic Message to a ChatMessage. - :param content_block: The content block of the message. - :param message: The non-streaming Anthropic Message. - :returns: The ChatMessage. + Converts the response from the Anthropic API to a ChatMessage. """ - if isinstance(content_block, TextBlock): - chat_message = ChatMessage.from_assistant(content_block.text) - else: - chat_message = ChatMessage.from_assistant(json.dumps(content_block.model_dump(mode="json"))) - chat_message.meta.update( + tool_calls = [ + ToolCall(tool_name=block.name, arguments=block.input, id=block.id) + for block in anthropic_response.content + if block.type == "tool_use" + ] + + # Extract and join text blocks, respecting ignore_tools_thinking_messages + text = "" + if not (ignore_tools_thinking_messages and tool_calls): + text = " ".join(block.text for block in anthropic_response.content if block.type == "text") + + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls) + + # Dump the chat completion to a dict + response_dict = anthropic_response.model_dump() + usage = self._get_openai_compatible_usage(response_dict) + message._meta.update( { - "model": message.model, + "model": response_dict.get("model", None), "index": 0, - "finish_reason": message.stop_reason, - "usage": dict(message.usage or {}), + "finish_reason": response_dict.get("stop_reason", None), + "usage": usage, } ) - return chat_message + return message - def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: + def _convert_anthropic_chunk_to_streaming_chunk(self, chunk: Any) -> StreamingChunk: """ - Converts the list of ChatMessage to the list of messages in the format expected by the Anthropic API. - :param messages: The list of ChatMessage. - :returns: The list of messages in the format expected by the Anthropic API. + Converts an Anthropic StreamEvent to a StreamingChunk. """ - anthropic_formatted_messages = [] - for m in messages: - message_dict = m.to_dict() - formatted_message = {} - - # legacy format - if "role" in message_dict and "content" in message_dict: - formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} - # new format - elif "_role" in message_dict and "_content" in message_dict: - formatted_message = {"role": m.role.value, "content": m.text} - - if m.is_from(ChatRole.SYSTEM): - # system messages are treated differently and MUST be in the format expected by the Anthropic API - # remove role and content from the message dict, add type and text - formatted_message.pop("role") - formatted_message["type"] = "text" - formatted_message["text"] = formatted_message.pop("content") - formatted_message.update(m.meta or {}) - anthropic_formatted_messages.append(formatted_message) - return anthropic_formatted_messages - - def _connect_chunks( - self, chunks: List[StreamingChunk], message_start: MessageStartEvent, delta: MessageDeltaEvent + content = "" + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + content = chunk.delta.text + + return StreamingChunk(content=content, meta=chunk.model_dump()) + + def _convert_streaming_chunks_to_chat_message( + self, chunks: List[StreamingChunk], model: Optional[str] = None ) -> ChatMessage: """ - Connects the streaming chunks into a single ChatMessage. - :param chunks: The list of all chunks returned by the Anthropic API. 
- :param message_start: The MessageStartEvent. - :param delta: The MessageDeltaEvent. - :returns: The complete ChatMessage. + Converts a list of StreamingChunks to a ChatMessage. """ - complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks])) - complete_response.meta.update( + full_content = "" + tool_calls = [] + current_tool_call: Optional[Dict[str, Any]] = {} + + # loop through chunks and call the appropriate handler + for chunk in chunks: + chunk_type = chunk.meta.get("type") + if chunk_type == "content_block_start": + if chunk.meta.get("content_block", {}).get("type") == "tool_use": + delta_block = chunk.meta.get("content_block") + current_tool_call = { + "id": delta_block.get("id"), + "name": delta_block.get("name"), + "arguments": "", + } + elif chunk_type == "content_block_delta": + delta = chunk.meta.get("delta", {}) + if delta.get("type") == "text_delta": + full_content += delta.get("text", "") + elif delta.get("type") == "input_json_delta" and current_tool_call: + current_tool_call["arguments"] += delta.get("partial_json", "") + elif chunk_type == "message_delta": + if chunk.meta.get("delta", {}).get("stop_reason") == "tool_use" and current_tool_call: + try: + # arguments is a string, convert to json + tool_calls.append( + ToolCall( + id=current_tool_call.get("id"), + tool_name=str(current_tool_call.get("name")), + arguments=json.loads(current_tool_call.get("arguments", {})), + ) + ) + except json.JSONDecodeError: + logger.warning( + "Anthropic returned a malformed JSON string for tool call arguments. " + f"This tool call will be skipped. Arguments: {current_tool_call.get('arguments', '')}", + ) + current_tool_call = None + + message = ChatMessage.from_assistant(full_content, tool_calls=tool_calls) + + # Update meta information + last_chunk_meta = chunks[-1].meta + usage = self._get_openai_compatible_usage(last_chunk_meta) + message._meta.update( { - "model": self.model, + "model": model, "index": 0, - "finish_reason": delta.delta.stop_reason if delta else "end_turn", - "usage": {**dict(message_start.message.usage, **dict(delta.usage))} if delta and message_start else {}, + "finish_reason": last_chunk_meta.get("delta", {}).get("stop_reason", None), + "usage": usage, } ) - return complete_response - def _build_chunk(self, delta: TextDelta) -> StreamingChunk: - """ - Converts the ContentBlockDeltaEvent to a StreamingChunk. - :param delta: The ContentBlockDeltaEvent. - :returns: The StreamingChunk. - """ - return StreamingChunk(content=delta.text) + return message def _remove_cache_control(self, message: Dict[str, Any]) -> Dict[str, Any]: """ @@ -330,3 +377,104 @@ def _remove_cache_control(self, message: Dict[str, Any]) -> Dict[str, Any]: :returns: The message with the cache_control key removed. """ return {k: v for k, v in message.items() if k != "cache_control"} + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): + """ + Invokes the Anthropic API with the given messages and generation kwargs. + + :param messages: A list of ChatMessage instances representing the input messages. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + :param generation_kwargs: Optional arguments to pass to the Anthropic generation endpoint. 
+ :param tools: A list of tools for which the model can prepare calls. If set, it will override + the `tools` parameter set during component initialization. + :returns: A dictionary with the following keys: + - `replies`: The responses from the model + """ + # update generation kwargs by merging with the generation kwargs passed to the run method + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + disallowed_params = set(generation_kwargs) - set(self.ALLOWED_PARAMS) + if disallowed_params: + logger.warning( + "Model parameters %s are not allowed and will be ignored. Allowed parameters are %s.", + disallowed_params, + self.ALLOWED_PARAMS, + ) + generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS} + + system_messages, non_system_messages = _convert_messages_to_anthropic_format(messages) + + # prompt caching + extra_headers = generation_kwargs.get("extra_headers", {}) + prompt_caching_on = "anthropic-beta" in extra_headers and "prompt-caching" in extra_headers["anthropic-beta"] + has_cached_messages = any("cache_control" in m for m in system_messages) or any( + "cache_control" in m for m in non_system_messages + ) + if has_cached_messages and not prompt_caching_on: + # this avoids Anthropic errors when prompt caching is not enabled + # but user requested individual messages to be cached + logger.warn( + "Prompt caching is not enabled but you requested individual messages to be cached. " + "Messages will be sent to the API without prompt caching." + ) + system_messages = list(map(self._remove_cache_control, system_messages)) + non_system_messages = list(map(self._remove_cache_control, non_system_messages)) + + # tools management + tools = tools or self.tools + _check_duplicate_tool_names(tools) + anthropic_tools = ( + [ + { + "name": tool.name, + "description": tool.description, + "input_schema": tool.parameters, + } + for tool in tools + ] + if tools + else [] + ) + + streaming_callback = streaming_callback or self.streaming_callback + + response = self.client.messages.create( + model=self.model, + messages=non_system_messages, + system=system_messages, + tools=anthropic_tools, + stream=streaming_callback is not None, + max_tokens=generation_kwargs.pop("max_tokens", 1024), + **generation_kwargs, + ) + + if isinstance(response, Stream): + chunks: List[StreamingChunk] = [] + model: Optional[str] = None + for chunk in response: + if chunk.type == "message_start": + model = chunk.message.model + elif chunk.type in [ + "content_block_start", + "content_block_delta", + "message_delta", + ]: + streaming_chunk = self._convert_anthropic_chunk_to_streaming_chunk(chunk) + chunks.append(streaming_chunk) + if streaming_callback: + streaming_callback(streaming_chunk) + + completion = self._convert_streaming_chunks_to_chat_message(chunks, model) + return {"replies": [completion]} + else: + return { + "replies": [ + self._convert_chat_completion_to_chat_message(response, self.ignore_tools_thinking_messages) + ] + } diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py index 4ece944cd..3004fcfb7 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py @@ -1,8 +1,9 @@ import os -from 
typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses import StreamingChunk +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils import deserialize_callable, serialize_callable from anthropic import AnthropicVertex @@ -65,6 +66,7 @@ def __init__( streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ignore_tools_thinking_messages: bool = True, + tools: Optional[List[Tool]] = None, ): """ Creates an instance of AnthropicVertexChatGenerator. @@ -92,7 +94,9 @@ def __init__( `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) for more details. + :param tools: A list of Tool objects that the model can use. Each tool should have a unique name. """ + _check_duplicate_tool_names(tools) self.region = region or os.environ.get("REGION") self.project_id = project_id or os.environ.get("PROJECT_ID") self.model = model @@ -100,6 +104,7 @@ def __init__( self.streaming_callback = streaming_callback self.client = AnthropicVertex(region=self.region, project_id=self.project_id) self.ignore_tools_thinking_messages = ignore_tools_thinking_messages + self.tools = tools def to_dict(self) -> Dict[str, Any]: """ @@ -109,6 +114,8 @@ def to_dict(self) -> Dict[str, Any]: The serialized component as a dictionary. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None + return default_to_dict( self, region=self.region, @@ -117,6 +124,7 @@ def to_dict(self) -> Dict[str, Any]: streaming_callback=callback_name, generation_kwargs=self.generation_kwargs, ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, + tools=serialized_tools, ) @classmethod @@ -128,8 +136,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator": :returns: The deserialized component instance. 
""" + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index d46ee624d..30c5986bc 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -1,117 +1,254 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 import json +import logging import os +from unittest.mock import patch import anthropic import pytest +from anthropic.types import ( + ContentBlockDeltaEvent, + ContentBlockStartEvent, + Message, + MessageStartEvent, + TextBlockParam, + TextDelta, +) +from haystack import Pipeline from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall +from haystack.tools import Tool from haystack.utils.auth import Secret -from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator +from haystack_integrations.components.generators.anthropic.chat.chat_generator import ( + AnthropicChatGenerator, + _convert_messages_to_anthropic_format, +) + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] @pytest.fixture def chat_messages(): return [ - ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), - ChatMessage.from_user("What's the capital of France?"), + ChatMessage.from_user("What's the capital of France"), ] +@pytest.fixture +def mock_anthropic_completion(): + with patch("anthropic.resources.messages.Messages.create") as mock_anthropic: + completion = Message( + id="foo", + type="message", + model="claude-3-5-sonnet-20240620", + role="assistant", + content=[TextBlockParam(type="text", text="Hello! I'm Claude.")], + stop_reason="end_turn", + usage={"input_tokens": 10, "output_tokens": 20}, + ) + mock_anthropic.return_value = completion + yield mock_anthropic + + class TestAnthropicChatGenerator: def test_init_default(self, monkeypatch): + """ + Test the default initialization of the AnthropicChatGenerator component. + """ monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") component = AnthropicChatGenerator() assert component.client.api_key == "test-api-key" assert component.model == "claude-3-5-sonnet-20240620" assert component.streaming_callback is None assert not component.generation_kwargs - assert component.ignore_tools_thinking_messages + assert component.tools is None def test_init_fail_wo_api_key(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component fails to initialize without an API key. 
+ """ monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - with pytest.raises(ValueError, match="None of the .* environment variables are set"): + with pytest.raises(ValueError): AnthropicChatGenerator() - def test_init_with_parameters(self): + def test_init_fail_with_duplicate_tool_names(self, monkeypatch, tools): + """ + Test that the AnthropicChatGenerator component fails to initialize with duplicate tool names. + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + AnthropicChatGenerator(tools=duplicate_tools) + + def test_init_with_parameters(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component initializes with parameters. + """ + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=lambda x: x) + + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") component = AnthropicChatGenerator( api_key=Secret.from_token("test-api-key"), model="claude-3-5-sonnet-20240620", streaming_callback=print_streaming_chunk, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ignore_tools_thinking_messages=False, + tools=[tool], + ) + assert component.client.api_key == "test-api-key" + assert component.model == "claude-3-5-sonnet-20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.tools == [tool] + + def test_init_with_parameters_and_env_vars(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component initializes with parameters and env vars. + """ + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + component = AnthropicChatGenerator( + model="claude-3-5-sonnet-20240620", + api_key=Secret.from_token("test-api-key"), + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) assert component.client.api_key == "test-api-key" assert component.model == "claude-3-5-sonnet-20240620" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - assert component.ignore_tools_thinking_messages is False def test_to_dict_default(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component can be serialized to a dictionary. + """ monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") component = AnthropicChatGenerator() data = component.to_dict() assert data == { "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { - "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "type": "env_var", "strict": True}, "model": "claude-3-5-sonnet-20240620", "streaming_callback": None, - "generation_kwargs": {}, "ignore_tools_thinking_messages": True, + "generation_kwargs": {}, + "tools": None, }, } def test_to_dict_with_parameters(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component can be serialized to a dictionary with parameters. 
+ """ + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + monkeypatch.setenv("ENV_VAR", "test-api-key") component = AnthropicChatGenerator( api_key=Secret.from_env_var("ENV_VAR"), + model="claude-3-5-sonnet-20240620", streaming_callback=print_streaming_chunk, generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + tools=[tool], ) data = component.to_dict() + assert data == { "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { - "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["ENV_VAR"], "type": "env_var", "strict": True}, "model": "claude-3-5-sonnet-20240620", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, "ignore_tools_thinking_messages": True, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": [ + { + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + "type": "haystack.tools.tool.Tool", + } + ], }, } def test_from_dict(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component can be deserialized from a dictionary. + """ monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { - "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "type": "env_var", "strict": True}, "model": "claude-3-5-sonnet-20240620", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - "ignore_tools_thinking_messages": True, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + }, + ], }, } component = AnthropicChatGenerator.from_dict(data) + + assert isinstance(component, AnthropicChatGenerator) assert component.model == "claude-3-5-sonnet-20240620" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} assert component.api_key == Secret.from_env_var("ANTHROPIC_API_KEY") + assert component.tools == [ + Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + ] def test_from_dict_fail_wo_env_var(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component fails to deserialize from a dictionary without an API key. 
+ """ monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) data = { "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", "init_parameters": { - "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "type": "env_var", "strict": True}, "model": "claude-3-5-sonnet-20240620", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - "ignore_tools_thinking_messages": True, }, } - with pytest.raises(ValueError, match="None of the .* environment variables are set"): + with pytest.raises(ValueError): AnthropicChatGenerator.from_dict(data) def test_run(self, chat_messages, mock_chat_completion): @@ -125,23 +262,323 @@ def test_run(self, chat_messages, mock_chat_completion): assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_run_with_params(self, chat_messages, mock_chat_completion): + def test_run_with_params(self, chat_messages, mock_anthropic_completion): + """ + Test that the AnthropicChatGenerator component can run with parameters. + """ component = AnthropicChatGenerator( api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} ) response = component.run(chat_messages) - # check that the component calls the Anthropic API with the correct parameters - _, kwargs = mock_chat_completion.call_args + # Check that the component calls the Anthropic API with the correct parameters + _, kwargs = mock_anthropic_completion.call_args assert kwargs["max_tokens"] == 10 assert kwargs["temperature"] == 0.5 - # check that the component returns the correct response + # Check that the component returns the correct response assert isinstance(response, dict) assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) == 1 - assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert isinstance(response["replies"][0], ChatMessage) + assert "Hello! I'm Claude." in response["replies"][0].text + assert response["replies"][0].meta["model"] == "claude-3-5-sonnet-20240620" + assert response["replies"][0].meta["finish_reason"] == "end_turn" + + def test_check_duplicate_tool_names(self, tools): + """Test that the AnthropicChatGenerator component fails to initialize with duplicate tool names.""" + with pytest.raises(ValueError): + AnthropicChatGenerator(tools=tools + tools) + + def test_convert_anthropic_chunk_to_streaming_chunk(self): + """ + Test converting Anthropic stream events to Haystack StreamingChunks + """ + component = AnthropicChatGenerator(api_key=Secret.from_token("test-api-key")) + + # Test text delta chunk + text_delta_chunk = ContentBlockDeltaEvent( + type="content_block_delta", index=0, delta=TextDelta(type="text_delta", text="Hello, world!") + ) + streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(text_delta_chunk) + assert streaming_chunk.content == "Hello, world!" 
+ assert streaming_chunk.meta == { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Hello, world!"}, + } + + # Test non-text chunk (should have empty content) + message_start_chunk = MessageStartEvent( + type="message_start", + message={ + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20240620", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 25, "output_tokens": 1}, + }, + ) + streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(message_start_chunk) + assert streaming_chunk.content == "" + assert streaming_chunk.meta == { + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20240620", + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": 25, + "output_tokens": 1, + "cache_creation_input_tokens": None, + "cache_read_input_tokens": None, + }, + }, + } + + # Test tool use chunk (should have empty content) + tool_use_chunk = ContentBlockStartEvent( + type="content_block_start", + index=1, + content_block={"type": "tool_use", "id": "toolu_123", "name": "weather", "input": {"city": "Paris"}}, + ) + streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(tool_use_chunk) + assert streaming_chunk.content == "" + assert streaming_chunk.meta == { + "type": "content_block_start", + "index": 1, + "content_block": {"type": "tool_use", "id": "toolu_123", "name": "weather", "input": {"city": "Paris"}}, + } + + def test_convert_streaming_chunks_to_chat_message(self): + """ + Test converting streaming chunks to a chat message with tool calls + """ + # Create a sequence of streaming chunks that simulate Anthropic's response + chunks = [ + # Initial text content + StreamingChunk( + content="", + meta={"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}, + ), + StreamingChunk( + content="Let me check", + meta={ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Let me check"}, + }, + ), + StreamingChunk( + content=" the weather", + meta={ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": " the weather"}, + }, + ), + StreamingChunk(content="", meta={"type": "content_block_stop", "index": 0}), + # Tool use content + StreamingChunk( + content="", + meta={ + "type": "content_block_start", + "index": 1, + "content_block": {"type": "tool_use", "id": "toolu_123", "name": "weather", "input": {}}, + }, + ), + StreamingChunk( + content="", + meta={ + "type": "content_block_delta", + "index": 1, + "delta": {"type": "input_json_delta", "partial_json": '{"city":'}, + }, + ), + StreamingChunk( + content="", + meta={ + "type": "content_block_delta", + "index": 1, + "delta": {"type": "input_json_delta", "partial_json": ' "Paris"}'}, + }, + ), + StreamingChunk(content="", meta={"type": "content_block_stop", "index": 1}), + # Final message delta + StreamingChunk( + content="", + meta={ + "type": "message_delta", + "delta": {"stop_reason": "tool_use", "stop_sequence": None}, + "usage": {"completion_tokens": 40}, + }, + ), + ] + + component = AnthropicChatGenerator(api_key=Secret.from_token("test-api-key")) + message = component._convert_streaming_chunks_to_chat_message(chunks, model="claude-3-sonnet") + + # Verify the message content + assert message.text == "Let me check the weather" + + # Verify tool calls + assert 
len(message.tool_calls) == 1 + tool_call = message.tool_calls[0] + assert tool_call.id == "toolu_123" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + + # Verify meta information + assert message._meta["model"] == "claude-3-sonnet" + assert message._meta["index"] == 0 + assert message._meta["finish_reason"] == "tool_use" + assert message._meta["usage"] == {"completion_tokens": 40} + + def test_convert_streaming_chunks_to_chat_message_malformed_json(self, caplog): + """ + Test converting streaming chunks with malformed JSON in tool arguments (increases coverage) + """ + chunks = [ + # Initial text content + StreamingChunk( + content="", + meta={"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}, + ), + StreamingChunk( + content="Let me check the weather", + meta={ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Let me check the weather"}, + }, + ), + StreamingChunk(content="", meta={"type": "content_block_stop", "index": 0}), + # Tool use content with malformed JSON + StreamingChunk( + content="", + meta={ + "type": "content_block_start", + "index": 1, + "content_block": {"type": "tool_use", "id": "toolu_123", "name": "weather", "input": {}}, + }, + ), + StreamingChunk( + content="", + meta={ + "type": "content_block_delta", + "index": 1, + "delta": {"type": "input_json_delta", "partial_json": '{"city":'}, + }, + ), + StreamingChunk( + content="", + meta={ + "type": "content_block_delta", + "index": 1, + "delta": { + "type": "input_json_delta", + "partial_json": ' "Paris', # Missing closing quote and brace, malformed JSON + }, + }, + ), + StreamingChunk(content="", meta={"type": "content_block_stop", "index": 1}), + # Final message delta + StreamingChunk( + content="", + meta={ + "type": "message_delta", + "delta": {"stop_reason": "tool_use", "stop_sequence": None}, + "usage": {"completion_tokens": 40}, + }, + ), + ] + + component = AnthropicChatGenerator(api_key=Secret.from_token("test-api-key")) + message = component._convert_streaming_chunks_to_chat_message(chunks, model="claude-3-sonnet") + + # Verify the message content is preserve + assert message.text == "Let me check the weather" + + # But the tool_calls are empty + assert len(message.tool_calls) == 0 + + # and we have logged a warning + with caplog.at_level(logging.WARNING): + assert "Anthropic returned a malformed JSON string" in caplog.text + + def test_serde_in_pipeline(self): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY", strict=False), + model="claude-3-5-sonnet-20240620", + generation_kwargs={"temperature": 0.6}, + tools=[tool], + ) + + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + type_ = "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator" + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": type_, + "init_parameters": { + "api_key": {"type": "env_var", "env_vars": ["ANTHROPIC_API_KEY"], "strict": False}, + "model": "claude-3-5-sonnet-20240620", + "generation_kwargs": {"temperature": 0.6}, + "ignore_tools_thinking_messages": True, + "streaming_callback": None, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "name", + "description": 
"description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + }, + } + ], + }, + } + }, + "connections": [], + } + + pipeline_yaml = pipeline.dumps() + + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", + ) + @pytest.mark.integration + def test_live_run(self): + """ + Integration test that the AnthropicChatGenerator component can run with default parameters. + """ + component = AnthropicChatGenerator() + results = component.run(messages=[ChatMessage.from_user("What's the capital of France?")]) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + assert "claude-3-5-sonnet-20240620" in message.meta["model"] + assert message.meta["finish_reason"] == "end_turn" @pytest.mark.skipif( not os.environ.get("ANTHROPIC_API_KEY", None), @@ -153,96 +590,349 @@ def test_live_run_wrong_model(self, chat_messages): with pytest.raises(anthropic.NotFoundError): component.run(chat_messages) + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_streaming(self): + """ + Integration test that the AnthropicChatGenerator component can run with streaming. + """ + + class Callback: + def __init__(self): + self.responses = "" + self.counter = 0 + + def __call__(self, chunk: StreamingChunk) -> None: + self.counter += 1 + self.responses += chunk.content if chunk.content else "" + + callback = Callback() + component = AnthropicChatGenerator(streaming_callback=callback) + results = component.run([ChatMessage.from_user("What's the capital of France?")]) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + + assert "claude-3-5-sonnet-20240620" in message.meta["model"] + assert message.meta["finish_reason"] == "end_turn" + + assert callback.counter > 1 + assert "Paris" in callback.responses + + def test_convert_message_to_anthropic_format(self): + """ + Test that the AnthropicChatGenerator component can convert a ChatMessage to Anthropic format. 
+ """ + messages = [ChatMessage.from_system("You are good assistant")] + assert _convert_messages_to_anthropic_format(messages) == ( + [{"type": "text", "text": "You are good assistant"}], + [], + ) + + messages = [ChatMessage.from_user("I have a question")] + assert _convert_messages_to_anthropic_format(messages) == ( + [], + [{"role": "user", "content": [{"type": "text", "text": "I have a question"}]}], + ) + + messages = [ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"})] + assert _convert_messages_to_anthropic_format(messages) == ( + [], + [{"role": "assistant", "content": [{"type": "text", "text": "I have an answer"}]}], + ) + + messages = [ + ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] + ) + ] + result = _convert_messages_to_anthropic_format(messages) + assert result == ( + [], + [ + { + "role": "assistant", + "content": [{"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}], + } + ], + ) + + messages = [ + ChatMessage.from_assistant( + text="", # this should not happen, but we should handle it without errors + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})], + ) + ] + result = _convert_messages_to_anthropic_format(messages) + assert result == ( + [], + [ + { + "role": "assistant", + "content": [{"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}], + } + ], + ) + + tool_result = json.dumps({"weather": "sunny", "temperature": "25"}) + messages = [ + ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ) + ] + assert _convert_messages_to_anthropic_format(messages) == ( + [], + [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "123", + "content": [{"type": "text", "text": '{"weather": "sunny", "temperature": "25"}'}], + "is_error": False, + } + ], + } + ], + ) + + messages = [ + ChatMessage.from_assistant( + text="For that I'll need to check the weather", + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})], + ) + ] + result = _convert_messages_to_anthropic_format(messages) + assert result == ( + [], + [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "For that I'll need to check the weather"}, + {"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}, + ], + } + ], + ) + + def test_convert_message_to_anthropic_format_complex(self): + """ + Test that the AnthropicChatGenerator can convert a complex sequence of ChatMessages to Anthropic format. + In particular, we check that different tool results are packed in a single dictionary with role=user. + """ + + messages = [ + ChatMessage.from_system("You are good assistant"), + ChatMessage.from_user("What's the weather like in Paris? 
And how much is 2+2?"), + ChatMessage.from_assistant( + text="", + tool_calls=[ + ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}), + ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}), + ], + ), + ChatMessage.from_tool( + tool_result="22° C", origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ), + ChatMessage.from_tool( + tool_result="4", origin=ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}) + ), + ] + + system_messages, non_system_messages = _convert_messages_to_anthropic_format(messages) + + assert system_messages == [{"type": "text", "text": "You are good assistant"}] + assert non_system_messages == [ + { + "role": "user", + "content": [{"type": "text", "text": "What's the weather like in Paris? And how much is 2+2?"}], + }, + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}, + {"type": "tool_use", "id": "456", "name": "math", "input": {"expression": "2+2"}}, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "123", + "content": [{"type": "text", "text": "22° C"}], + "is_error": False, + }, + { + "type": "tool_result", + "tool_use_id": "456", + "content": [{"type": "text", "text": "4"}], + "is_error": False, + }, + ], + }, + ] + + def test_convert_message_to_anthropic_invalid(self): + """ + Test that the AnthropicChatGenerator component fails to convert an invalid ChatMessage to Anthropic format. + """ + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + _convert_messages_to_anthropic_format([message]) + + tool_call_null_id = ToolCall(id=None, tool_name="weather", arguments={"city": "Paris"}) + message = ChatMessage.from_assistant(tool_calls=[tool_call_null_id]) + with pytest.raises(ValueError): + _convert_messages_to_anthropic_format([message]) + + message = ChatMessage.from_tool(tool_result="result", origin=tool_call_null_id) + with pytest.raises(ValueError): + _convert_messages_to_anthropic_format([message]) + @pytest.mark.skipif( not os.environ.get("ANTHROPIC_API_KEY", None), reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", ) @pytest.mark.integration - def test_default_inference_params(self, chat_messages): - client = AnthropicChatGenerator() - response = client.run(chat_messages) - - assert "replies" in response, "Response does not contain 'replies' key" - replies = response["replies"] - assert isinstance(replies, list), "Replies is not a list" - assert len(replies) > 0, "No replies received" - - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.text, "First reply has no text" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'" - assert first_reply.meta, "First reply has no metadata" + def test_live_run_with_tools(self, tools): + """ + Integration test that the AnthropicChatGenerator component can run with tools. 
+ """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AnthropicChatGenerator(tools=tools) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.id is not None + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_use" + + new_messages = [ + *initial_messages, + message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + # the model tends to make tool calls if provided with tools, so we don't pass them here + results = component.run(new_messages, generation_kwargs={"max_tokens": 50}) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() @pytest.mark.skipif( not os.environ.get("ANTHROPIC_API_KEY", None), reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", ) @pytest.mark.integration - def test_default_inference_with_streaming(self, chat_messages): - streaming_callback_called = False - paris_found_in_response = False - - def streaming_callback(chunk: StreamingChunk): - nonlocal streaming_callback_called, paris_found_in_response - streaming_callback_called = True - assert isinstance(chunk, StreamingChunk) - assert chunk.content is not None - if not paris_found_in_response: - paris_found_in_response = "paris" in chunk.content.lower() - - client = AnthropicChatGenerator(streaming_callback=streaming_callback) - response = client.run(chat_messages) - - assert streaming_callback_called, "Streaming callback was not called" - assert paris_found_in_response, "The streaming callback response did not contain 'paris'" - replies = response["replies"] - assert isinstance(replies, list), "Replies is not a list" - assert len(replies) > 0, "No replies received" - - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.text, "First reply has no text" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.text.lower(), "First reply does not contain 'paris'" - assert first_reply.meta, "First reply has no metadata" + def test_live_run_with_tools_streaming(self, tools): + """ + Integration test that the AnthropicChatGenerator component can run with tools and streaming. 
+ """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AnthropicChatGenerator(tools=tools, streaming_callback=print_streaming_chunk) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + + # this is Antropic thinking message prior to tool call + assert message.text is not None + assert "weather" in message.text.lower() + assert "paris" in message.text.lower() + + # now we have the tool call + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.id is not None + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_use" + + new_messages = [ + *initial_messages, + message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + results = component.run(new_messages) + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() @pytest.mark.skipif( not os.environ.get("ANTHROPIC_API_KEY", None), reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", ) @pytest.mark.integration - def test_tools_use(self): - # See https://docs.anthropic.com/en/docs/tool-use for more information - tools_schema = { - "name": "get_stock_price", - "description": "Retrieves the current stock price for a given ticker symbol.", - "input_schema": { - "type": "object", - "properties": { - "ticker": {"type": "string", "description": "The stock ticker symbol, e.g. AAPL for Apple Inc."} - }, - "required": ["ticker"], - }, - } - client = AnthropicChatGenerator() - response = client.run( - messages=[ChatMessage.from_user("What is the current price of AAPL?")], - generation_kwargs={"tools": [tools_schema]}, - ) - replies = response["replies"] - assert isinstance(replies, list), "Replies is not a list" - assert len(replies) > 0, "No replies received" - - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.text, "First reply has no text" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "get_stock_price" in first_reply.text.lower(), "First reply does not contain get_stock_price" - assert first_reply.meta, "First reply has no metadata" - fc_response = json.loads(first_reply.text) - assert "name" in fc_response, "First reply does not contain name of the tool" - assert "input" in fc_response, "First reply does not contain input of the tool" + def test_live_run_with_parallel_tools(self, tools): + """ + Integration test that the AnthropicChatGenerator component can run with parallel tools. 
+ """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")] + component = AnthropicChatGenerator(tools=tools) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + + # now we have the tool call + assert len(message.tool_calls) == 2 + tool_call_paris = message.tool_calls[0] + assert isinstance(tool_call_paris, ToolCall) + assert tool_call_paris.id is not None + assert tool_call_paris.tool_name == "weather" + assert tool_call_paris.arguments["city"] in {"Paris", "Berlin"} + assert message.meta["finish_reason"] == "tool_use" + + tool_call_berlin = message.tool_calls[1] + assert isinstance(tool_call_berlin, ToolCall) + assert tool_call_berlin.id is not None + assert tool_call_berlin.tool_name == "weather" + assert tool_call_berlin.arguments["city"] in {"Berlin", "Paris"} + + # Anthropic expects results from both tools in the same message + # https://docs.anthropic.com/en/docs/build-with-claude/tool-use#handling-tool-use-and-tool-result-content-blocks + # [optional] Continue the conversation by sending a new message with the role of user, and a content block + # containing the tool_result type and the following information: + # tool_use_id: The id of the tool use request this is a result for. + # content: The result of the tool, as a string (e.g. "content": "15 degrees") or list of + # nested content blocks (e.g. "content": [{"type": "text", "text": "15 degrees"}]). + # These content blocks can use the text or image types. + # is_error (optional): Set to true if the tool execution resulted in an error. + new_messages = [ + *initial_messages, + message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call_paris, error=False), + ChatMessage.from_tool(tool_result="12° C", origin=tool_call_berlin, error=False), + ] + + # Response from the model contains results from both tools + results = component.run(new_messages) + message = results["replies"][0] + assert not message.tool_calls + assert len(message.text) > 0 + assert "paris" in message.text.lower() + assert "berlin" in message.text.lower() + assert "22°" in message.text + assert "12°" in message.text + assert message.meta["finish_reason"] == "end_turn" def test_prompt_caching_enabled(self, monkeypatch): """ @@ -267,7 +957,7 @@ def test_prompt_caching_cache_control_without_extra_headers(self, monkeypatch, m # Add cache_control to messages for msg in messages: - msg.meta["cache_control"] = {"type": "ephemeral"} + msg._meta["cache_control"] = {"type": "ephemeral"} # Invoke run with messages component.run(messages) @@ -333,57 +1023,10 @@ def test_from_dict_with_prompt_caching(self, monkeypatch): component = AnthropicChatGenerator.from_dict(data) assert component.generation_kwargs["extra_headers"]["anthropic-beta"] == "prompt-caching-2024-07-31" - def test_convert_messages_to_anthropic_format(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") - generator = AnthropicChatGenerator() - - # Test scenario 1: Regular user and assistant messages - messages = [ - ChatMessage.from_user("Hello"), - ChatMessage.from_assistant("Hi there!"), - ] - result = generator._convert_to_anthropic_format(messages) - assert result == [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ] - - # Test scenario 2: System message - messages = [ChatMessage.from_system("You are a helpful assistant.")] - result = generator._convert_to_anthropic_format(messages) - assert result == [{"type": 
"text", "text": "You are a helpful assistant."}] - - # Test scenario 3: Mixed message types - messages = [ - ChatMessage.from_system("Be concise."), - ChatMessage.from_user("What's AI?"), - ChatMessage.from_assistant("Artificial Intelligence."), - ] - result = generator._convert_to_anthropic_format(messages) - assert result == [ - {"type": "text", "text": "Be concise."}, - {"role": "user", "content": "What's AI?"}, - {"role": "assistant", "content": "Artificial Intelligence."}, - ] - - # Test scenario 4: metadata - messages = [ - ChatMessage.from_user("What's AI?"), - ChatMessage.from_assistant("Artificial Intelligence.", meta={"confidence": 0.9}), - ] - result = generator._convert_to_anthropic_format(messages) - assert result == [ - {"role": "user", "content": "What's AI?"}, - {"role": "assistant", "content": "Artificial Intelligence.", "confidence": 0.9}, - ] - - # Test scenario 5: Empty message list - assert generator._convert_to_anthropic_format([]) == [] - @pytest.mark.integration @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY", None), reason="ANTHROPIC_API_KEY not set") @pytest.mark.parametrize("cache_enabled", [True, False]) - def test_prompt_caching(self, cache_enabled): + def test_prompt_caching_live_run(self, cache_enabled): generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if cache_enabled else {} claude_llm = AnthropicChatGenerator( @@ -393,7 +1036,7 @@ def test_prompt_caching(self, cache_enabled): # see https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations system_message = ChatMessage.from_system("This is the cached, here we make it at least 1024 tokens long." * 70) if cache_enabled: - system_message.meta["cache_control"] = {"type": "ephemeral"} + system_message._meta["cache_control"] = {"type": "ephemeral"} messages = [system_message, ChatMessage.from_user("What's in cached content?")] result = claude_llm.run(messages) diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py index 6c3a30d89..43df6fab3 100644 --- a/integrations/anthropic/tests/test_vertex_chat_generator.py +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -57,6 +57,7 @@ def test_to_dict_default(self): "streaming_callback": None, "generation_kwargs": {}, "ignore_tools_thinking_messages": True, + "tools": None, }, } @@ -80,6 +81,7 @@ def test_to_dict_with_parameters(self): "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, "ignore_tools_thinking_messages": True, + "tools": None, }, } From 770d4f1a1a7a223aec101080430fd1756d7c69c1 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 20 Jan 2025 09:40:36 +0100 Subject: [PATCH 194/229] pin transformers!=4.48.0 (#1302) --- integrations/amazon_bedrock/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 872d4933b..ad9754a33 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "boto3>=1.28.57", "transformers"] +dependencies = ["haystack-ai", "boto3>=1.28.57", "transformers!=4.48.0"] [project.urls] Documentation = 
"https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme" From e973a190aae700d6804461eb387376c9db7f00fc Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 20 Jan 2025 08:47:36 +0000 Subject: [PATCH 195/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 6a15d4ad2..368b14f9e 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/amazon_bedrock-v2.1.2] - 2025-01-20 + +### 🌀 Miscellaneous + +- Fix: Bedrock - pin `transformers!=4.48.0` (#1302) + ## [integrations/amazon_bedrock-v2.1.1] - 2024-12-18 ### 🐛 Bug Fixes From ae2779334da2cd9a45be4941fe311d8a454b2d89 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 20 Jan 2025 12:25:02 +0100 Subject: [PATCH 196/229] fix: Optimum - do not explictly pin `transformers` (#1303) * pin transformers!=4.48.0 * lint * try removing explicit transformers dep --- integrations/optimum/pyproject.toml | 1 - integrations/optimum/tests/test_optimum_document_embedder.py | 4 +++- integrations/optimum/tests/test_optimum_text_embedder.py | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml index 6149997ed..d7446f859 100644 --- a/integrations/optimum/pyproject.toml +++ b/integrations/optimum/pyproject.toml @@ -27,7 +27,6 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "transformers[sentencepiece]", # The main export function of Optimum into ONNX has hidden dependencies. # It depends on either "sentence-transformers", "diffusers" or "timm", based # on which model is loaded from HF Hub. 
diff --git a/integrations/optimum/tests/test_optimum_document_embedder.py b/integrations/optimum/tests/test_optimum_document_embedder.py index 7c8ca02e0..32db9258f 100644 --- a/integrations/optimum/tests/test_optimum_document_embedder.py +++ b/integrations/optimum/tests/test_optimum_document_embedder.py @@ -99,7 +99,9 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 assert embedder._backend.parameters.optimizer_settings is None assert embedder._backend.parameters.quantizer_settings is None - def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 + def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode, monkeypatch): # noqa: ARG002 + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) component = OptimumDocumentEmbedder() data = component.to_dict() diff --git a/integrations/optimum/tests/test_optimum_text_embedder.py b/integrations/optimum/tests/test_optimum_text_embedder.py index db42ec26d..4343d1a0f 100644 --- a/integrations/optimum/tests/test_optimum_text_embedder.py +++ b/integrations/optimum/tests/test_optimum_text_embedder.py @@ -84,7 +84,9 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 assert embedder._backend.parameters.optimizer_settings is None assert embedder._backend.parameters.quantizer_settings is None - def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 + def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode, monkeypatch): # noqa: ARG002 + monkeypatch.delenv("HF_API_TOKEN", raising=False) + monkeypatch.delenv("HF_TOKEN", raising=False) component = OptimumTextEmbedder() data = component.to_dict() From 85f1b08cc83090538f9b0c8d375333c8c337a292 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 20 Jan 2025 11:26:24 +0000 Subject: [PATCH 197/229] Update the changelog --- integrations/optimum/CHANGELOG.md | 39 ++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/integrations/optimum/CHANGELOG.md b/integrations/optimum/CHANGELOG.md index 6699bef7a..8f7025b71 100644 --- a/integrations/optimum/CHANGELOG.md +++ b/integrations/optimum/CHANGELOG.md @@ -1,23 +1,52 @@ # Changelog -## [integrations/optimum-v0.1.1] - 2024-07-04 +## [integrations/optimum-v0.1.2] - 2025-01-20 ### 🐛 Bug Fixes -- Fix docs build (#633) +- Optimum - do not explictly pin `transformers` (#1303) + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) + +### ⚙️ CI +- Adopt uv as installer (#1142) + +### 🧹 Chores + +- Update ruff linting scripts and settings (#1105) +- Fix linting/isort (#1215) + + +## [integrations/optimum-v0.1.1] - 2024-07-04 +### 🐛 Bug Fixes + +- Fix docs build (#633) - Fix typo in the `ORTModel.inputs_names` field to align with upstream (#866) ### 📚 Documentation - Disable-class-def (#556) -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) +### 🌀 Miscellaneous + +- Make tests show coverage (#566) +- Remove references to Python 3.7 (#601) +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) + ## [integrations/optimum-v0.1.0] - 2024-03-04 ### 🚀 Features @@ -34,4 +63,8 @@ - Fix Optimum embedder examples (#517) +### 🌀 Miscellaneous + +- Refactor `optimum` namespacing + 
bug fixes (#469) + From 0678b7d13493569d0bc3ae59247180c0814a7bb2 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 20 Jan 2025 11:41:34 +0000 Subject: [PATCH 198/229] Update the changelog --- integrations/anthropic/CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/integrations/anthropic/CHANGELOG.md b/integrations/anthropic/CHANGELOG.md index c14a8032d..d361f0f3c 100644 --- a/integrations/anthropic/CHANGELOG.md +++ b/integrations/anthropic/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [unreleased] + +### 🌀 Miscellaneous + +- Test: remove tests involving serialization of lambdas (#1281) +- Test: remove more tests involving serialization of lambdas (#1285) +- Feat: Anthropic - support for Tools + refactoring (#1300) + ## [integrations/anthropic-v1.2.1] - 2024-12-18 ### 🐛 Bug Fixes From 0c629bbadd0516e467939a5fe6b580a5d6eac1c4 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 21 Jan 2025 12:07:11 +0100 Subject: [PATCH 199/229] Pinecone - increase sleep time (#1307) --- integrations/pinecone/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/pinecone/tests/conftest.py b/integrations/pinecone/tests/conftest.py index 074ed978d..1a6574adb 100644 --- a/integrations/pinecone/tests/conftest.py +++ b/integrations/pinecone/tests/conftest.py @@ -13,7 +13,7 @@ from haystack_integrations.document_stores.pinecone import PineconeDocumentStore # This is the approximate time in seconds it takes for the documents to be available -SLEEP_TIME_IN_SECONDS = 15 +SLEEP_TIME_IN_SECONDS = 18 @pytest.fixture() From 17f6a1f7f260514b3481a258518bb45be9758811 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 21 Jan 2025 13:29:13 +0100 Subject: [PATCH 200/229] fix: End langfuse generation spans properly (#1301) * End langfuse generation spans properly * Linting --- .../src/haystack_integrations/tracing/langfuse/tracer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index 1b7187f30..50fab2c8a 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -201,7 +201,11 @@ def trace( ) raw_span = span.raw_span() - if isinstance(raw_span, langfuse.client.StatefulSpanClient): + + # In this section, we finalize both regular spans and generation spans created using the LangfuseSpan class. + # It's important to end() these spans to ensure they are properly closed and all relevant data is recorded. + # Note that we do not call end() on the main trace span itself, as its lifecycle is managed differently. 
+ if isinstance(raw_span, (langfuse.client.StatefulSpanClient, langfuse.client.StatefulGenerationClient)): raw_span.end() self._context.pop() From b7ae1de2127d59f42a354195fdc3f7f2769d5a21 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 21 Jan 2025 13:29:54 +0100 Subject: [PATCH 201/229] feat: LangfuseConnector - add httpx.Client init param (#1308) * LangfuseConnector - add httpx.Client init param * Add dep * Add test * PR feedback * Improve pydocs * PR feedback - wording * Lint --- integrations/langfuse/pyproject.toml | 8 +++++- .../connectors/langfuse/langfuse_connector.py | 7 +++++ integrations/langfuse/tests/test_tracing.py | 26 ++++++++++++++++++- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/integrations/langfuse/pyproject.toml b/integrations/langfuse/pyproject.toml index 92ebc7b8f..4cf8b0e4c 100644 --- a/integrations/langfuse/pyproject.toml +++ b/integrations/langfuse/pyproject.toml @@ -67,7 +67,12 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] installer = "uv" detached = true -dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = [ + "pip", + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" @@ -167,6 +172,7 @@ module = [ "haystack_integrations.*", "pytest.*", "numpy.*", + "httpx.*", ] ignore_missing_imports = true diff --git a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py index 1762a92de..9418b2270 100644 --- a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py +++ b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py @@ -1,5 +1,6 @@ from typing import Any, Dict, Optional +import httpx from haystack import component, default_from_dict, default_to_dict, logging, tracing from haystack.utils import Secret, deserialize_secrets_inplace @@ -101,6 +102,7 @@ def __init__( public: bool = False, public_key: Optional[Secret] = Secret.from_env_var("LANGFUSE_PUBLIC_KEY"), # noqa: B008 secret_key: Optional[Secret] = Secret.from_env_var("LANGFUSE_SECRET_KEY"), # noqa: B008 + httpx_client: Optional[httpx.Client] = None, ): """ Initialize the LangfuseConnector component. @@ -112,6 +114,9 @@ def __init__( only accessible to the Langfuse account owner. The default is `False`. :param public_key: The Langfuse public key. Defaults to reading from LANGFUSE_PUBLIC_KEY environment variable. :param secret_key: The Langfuse secret key. Defaults to reading from LANGFUSE_SECRET_KEY environment variable. + :param httpx_client: Optional custom httpx.Client instance to use for Langfuse API calls. Note that when + deserializing a pipeline from YAML, any custom client is discarded and Langfuse will create its own default + client, since HTTPX clients cannot be serialized. 
""" self.name = name self.public = public @@ -121,6 +126,7 @@ def __init__( tracer=Langfuse( secret_key=secret_key.resolve_value() if secret_key else None, public_key=public_key.resolve_value() if public_key else None, + httpx_client=httpx_client, ), name=name, public=public, @@ -158,6 +164,7 @@ def to_dict(self) -> Dict[str, Any]: public=self.public, secret_key=self.secret_key.to_dict() if self.secret_key else None, public_key=self.public_key.to_dict() if self.public_key else None, + # Note: httpx_client is not serialized as it's not serializable ) @classmethod diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index d4815fc9c..06f65e72c 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -10,6 +10,7 @@ from haystack.dataclasses import ChatMessage from haystack.utils import Secret from requests.auth import HTTPBasicAuth +import httpx from haystack_integrations.components.connectors.langfuse import LangfuseConnector from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator @@ -49,6 +50,27 @@ def pipeline_with_secrets(llm_class, expected_trace): return pipe +@pytest.fixture +def pipeline_with_custom_client(llm_class, expected_trace): + """Pipeline factory using custom httpx client for Langfuse""" + pipe = Pipeline() + custom_client = httpx.Client(timeout=30.0) # Custom timeout of 30 seconds + pipe.add_component( + "tracer", + LangfuseConnector( + name=f"Chat example - {expected_trace}", + public=True, + secret_key=Secret.from_env_var("LANGFUSE_SECRET_KEY"), + public_key=Secret.from_env_var("LANGFUSE_PUBLIC_KEY"), + httpx_client=custom_client, + ), + ) + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", llm_class()) + pipe.connect("prompt_builder.prompt", "llm.messages") + return pipe + + @pytest.mark.integration @pytest.mark.parametrize( "llm_class, env_var, expected_trace", @@ -58,7 +80,9 @@ def pipeline_with_secrets(llm_class, expected_trace): (CohereChatGenerator, "COHERE_API_KEY", "Cohere"), ], ) -@pytest.mark.parametrize("pipeline_fixture", ["pipeline_with_env_vars", "pipeline_with_secrets"]) +@pytest.mark.parametrize( + "pipeline_fixture", ["pipeline_with_env_vars", "pipeline_with_secrets", "pipeline_with_custom_client"] +) def test_tracing_integration(llm_class, env_var, expected_trace, pipeline_fixture, request): if not all([os.environ.get("LANGFUSE_SECRET_KEY"), os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get(env_var)]): pytest.skip(f"Missing required environment variables: LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, or {env_var}") From 216383824516cd970954dee3bcded00f6b733790 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 21 Jan 2025 12:31:32 +0000 Subject: [PATCH 202/229] Update the changelog --- integrations/langfuse/CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 3eb8ddba6..414e44e41 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [integrations/langfuse-v0.7.0] - 2025-01-21 + +### 🚀 Features + +- LangfuseConnector - add httpx.Client init param (#1308) + +### 🐛 Bug Fixes + +- End langfuse generation spans properly (#1301) + + ## [integrations/langfuse-v0.6.4] - 2025-01-17 ### 🚀 Features From 7a939d7b16c45ecf382d59bc6a9ab53a6aac8341 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 21 Jan 2025 16:49:14 +0100 
Subject: [PATCH 203/229] chore: Bedrock - pin `transformers!=4.48.*` (#1306) * pin transformers!=4.48 * fix --- integrations/amazon_bedrock/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index ad9754a33..7044eb453 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "boto3>=1.28.57", "transformers!=4.48.0"] +dependencies = ["haystack-ai", "boto3>=1.28.57", "transformers!=4.48.*"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme" From a5bdb763cbae87265b8f2154a49e2dada9c7a5c3 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 21 Jan 2025 15:51:19 +0000 Subject: [PATCH 204/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 368b14f9e..7ba6b422c 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/amazon_bedrock-v2.1.3] - 2025-01-21 + +### 🧹 Chores + +- Bedrock - pin `transformers!=4.48.*` (#1306) + + ## [integrations/amazon_bedrock-v2.1.2] - 2025-01-20 ### 🌀 Miscellaneous From 01c538570750872a944871aad101afb7172ed64a Mon Sep 17 00:00:00 2001 From: Corentin Meyer Date: Thu, 23 Jan 2025 13:33:49 +0100 Subject: [PATCH 205/229] feat(AWS Bedrock): Add Cohere Reranker (#1291) * Amazon Bedrock: Add Cohere Rerank model * Run Lint * Remove changes to CHANGELOG.md * Remove var from serialization test * # noqa: B008 fix test lint, yada yada * adding BedrockRanker to pydoc --------- Co-authored-by: David S. 
Batista Co-authored-by: Stefano Fiorucci --- .../examples/bedrock_ranker_example.py | 32 +++ integrations/amazon_bedrock/pydoc/config.yml | 9 +- .../rankers/amazon_bedrock/__init__.py | 3 + .../rankers/amazon_bedrock/ranker.py | 233 ++++++++++++++++++ .../amazon_bedrock/tests/test_ranker.py | 103 ++++++++ 5 files changed, 376 insertions(+), 4 deletions(-) create mode 100644 integrations/amazon_bedrock/examples/bedrock_ranker_example.py create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/__init__.py create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/ranker.py create mode 100644 integrations/amazon_bedrock/tests/test_ranker.py diff --git a/integrations/amazon_bedrock/examples/bedrock_ranker_example.py b/integrations/amazon_bedrock/examples/bedrock_ranker_example.py new file mode 100644 index 000000000..b72805c12 --- /dev/null +++ b/integrations/amazon_bedrock/examples/bedrock_ranker_example.py @@ -0,0 +1,32 @@ +import os + +from haystack import Document +from haystack.utils import Secret + +from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker + +# Set up AWS credentials +# You can also set these as environment variables +aws_profile_name = os.environ.get("AWS_PROFILE") or "default" +aws_region_name = os.environ.get("AWS_DEFAULT_REGION") or "eu-central-1" +# Initialize the BedrockRanker with AWS credentials +ranker = BedrockRanker( + model="cohere.rerank-v3-5:0", + top_k=2, + aws_profile_name=Secret.from_token(aws_profile_name), + aws_region_name=Secret.from_token(aws_region_name), +) + +# Create some sample documents +docs = [ + Document(content="Paris is the capital of France."), + Document(content="Berlin is the capital of Germany."), + Document(content="London is the capital of the United Kingdom."), + Document(content="Rome is the capital of Italy."), +] + +# Define a query +query = "What is the capital of Germany?" 
+ +# Run the ranker +output = ranker.run(query=query, documents=docs) diff --git a/integrations/amazon_bedrock/pydoc/config.yml b/integrations/amazon_bedrock/pydoc/config.yml index 6cb05d6f3..e6d6eba78 100644 --- a/integrations/amazon_bedrock/pydoc/config.yml +++ b/integrations/amazon_bedrock/pydoc/config.yml @@ -2,13 +2,14 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../src] modules: [ + "haystack_integrations.common.amazon_bedrock.errors", + "haystack_integrations.components.embedders.amazon_bedrock.document_embedder", + "haystack_integrations.components.embedders.amazon_bedrock.text_embedder", "haystack_integrations.components.generators.amazon_bedrock.generator", "haystack_integrations.components.generators.amazon_bedrock.adapters", - "haystack_integrations.common.amazon_bedrock.errors", - "haystack_integrations.components.generators.amazon_bedrock.handlers", "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator", - "haystack_integrations.components.embedders.amazon_bedrock.text_embedder", - "haystack_integrations.components.embedders.amazon_bedrock.document_embedder", + "haystack_integrations.components.generators.amazon_bedrock.handlers", + "haystack_integrations.components.rankers.amazon_bedrock.ranker", ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/__init__.py new file mode 100644 index 000000000..8979d1f1b --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/__init__.py @@ -0,0 +1,3 @@ +from .ranker import BedrockRanker + +__all__ = ["BedrockRanker"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/ranker.py b/integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/ranker.py new file mode 100644 index 000000000..735e50b59 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/ranker.py @@ -0,0 +1,233 @@ +import json +import logging +from typing import Any, Dict, List, Optional + +from botocore.exceptions import ClientError +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace + +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, +) +from haystack_integrations.common.amazon_bedrock.utils import get_aws_session + +logger = logging.getLogger(__name__) + +MAX_NUM_DOCS_FOR_BEDROCK_RANKER = 1000 + + +@component +class BedrockRanker: + """ + Ranks Documents based on their similarity to the query using Amazon Bedrock's Cohere Rerank model. + + Documents are indexed from most to least semantically relevant to the query. + + Usage example: + ```python + from haystack import Document + from haystack.utils import Secret + from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker + + ranker = BedrockRanker(model="cohere.rerank-v3-5:0", top_k=2, aws_region_name=Secret.from_token("eu-central-1")) + + docs = [Document(content="Paris"), Document(content="Berlin")] + query = "What is the capital of germany?" + output = ranker.run(query=query, documents=docs) + docs = output["documents"] + ``` + + BedrockRanker uses AWS for authentication. 
You can use the AWS CLI to authenticate through your IAM. + For more information on setting up an IAM identity-based policy, see [Amazon Bedrock documentation] + (https://docs.aws.amazon.com/bedrock/latest/userguide/security_iam_id-based-policy-examples.html). + + If the AWS environment is configured correctly, the AWS credentials are not required as they're loaded + automatically from the environment or the AWS configuration file. + If the AWS environment is not configured, set `aws_access_key_id`, `aws_secret_access_key`, + and `aws_region_name` as environment variables or pass them as + [Secret](https://docs.haystack.deepset.ai/v2.0/docs/secret-management) arguments. Make sure the region you set + supports Amazon Bedrock. + """ + + def __init__( + self, + model: str = "cohere.rerank-v3-5:0", + top_k: int = 10, + aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008 + aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 + ["AWS_SECRET_ACCESS_KEY"], strict=False + ), + aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 + aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 + aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 + max_chunks_per_doc: Optional[int] = None, + meta_fields_to_embed: Optional[List[str]] = None, + meta_data_separator: str = "\n", + ): + if not model: + msg = "'model' cannot be None or empty string" + raise ValueError(msg) + """ + Creates an instance of the 'BedrockRanker'. + + :param model: Amazon Bedrock model name for Cohere Rerank. Default is "cohere.rerank-v3-5:0". + :param top_k: The maximum number of documents to return. + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param max_chunks_per_doc: If your document exceeds 512 tokens, this determines the maximum number of + chunks a document can be split into. If `None`, the default of 10 is used. + Note: This parameter is not currently used in the implementation but is included for future compatibility. + :param meta_fields_to_embed: List of meta fields that should be concatenated + with the document content for reranking. + :param meta_data_separator: Separator used to concatenate the meta fields + to the Document content. 
+ """ + self.model_name = model + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key + self.aws_session_token = aws_session_token + self.aws_region_name = aws_region_name + self.aws_profile_name = aws_profile_name + self.top_k = top_k + self.max_chunks_per_doc = max_chunks_per_doc + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.meta_data_separator = meta_data_separator + + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: + return secret.resolve_value() if secret else None + + try: + session = get_aws_session( + aws_access_key_id=resolve_secret(aws_access_key_id), + aws_secret_access_key=resolve_secret(aws_secret_access_key), + aws_session_token=resolve_secret(aws_session_token), + aws_region_name=resolve_secret(aws_region_name), + aws_profile_name=resolve_secret(aws_profile_name), + ) + self._bedrock_client = session.client("bedrock-runtime") + except Exception as exception: + msg = ( + "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) + raise AmazonBedrockConfigurationError(msg) from exception + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self.model_name, + aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, + aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None, + aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None, + aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, + aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, + top_k=self.top_k, + max_chunks_per_doc=self.max_chunks_per_doc, + meta_fields_to_embed=self.meta_fields_to_embed, + meta_data_separator=self.meta_data_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BedrockRanker": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + deserialize_secrets_inplace( + data["init_parameters"], + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + ) + return default_from_dict(cls, data) + + def _prepare_bedrock_input_docs(self, documents: List[Document]) -> List[str]: + """ + Prepare the input by concatenating the document text with the metadata fields specified. + :param documents: The list of Document objects. + + :return: A list of strings to be given as input to Bedrock model. + """ + concatenated_input_list = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta.get(key) + ] + concatenated_input = self.meta_data_separator.join([*meta_values_to_embed, doc.content or ""]) + concatenated_input_list.append(concatenated_input) + + return concatenated_input_list + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Use the Amazon Bedrock Cohere Reranker to re-rank the list of documents based on the query. + + :param query: + Query string. + :param documents: + List of Documents. 
+ :param top_k: + The maximum number of Documents you want the Ranker to return. + :returns: + A dictionary with the following keys: + - `documents`: List of Documents most similar to the given query in descending order of similarity. + + :raises ValueError: If `top_k` is not > 0. + """ + top_k = top_k or self.top_k + if top_k <= 0: + msg = f"top_k must be > 0, but got {top_k}" + raise ValueError(msg) + + bedrock_input_docs = self._prepare_bedrock_input_docs(documents) + if len(bedrock_input_docs) > MAX_NUM_DOCS_FOR_BEDROCK_RANKER: + logger.warning( + f"The Amazon Bedrock reranking endpoint only supports {MAX_NUM_DOCS_FOR_BEDROCK_RANKER} documents.\ + The number of documents has been truncated to {MAX_NUM_DOCS_FOR_BEDROCK_RANKER} \ + from {len(bedrock_input_docs)}." + ) + bedrock_input_docs = bedrock_input_docs[:MAX_NUM_DOCS_FOR_BEDROCK_RANKER] + + # Prepare the request body for Amazon Bedrock + request_body = {"documents": bedrock_input_docs, "query": query, "top_n": top_k, "api_version": 2} + + try: + # Make the API call to Amazon Bedrock + response = self._bedrock_client.invoke_model(modelId=self.model_name, body=json.dumps(request_body)) + + # Parse the response + response_body = json.loads(response["body"].read()) + results = response_body["results"] + + # Sort documents based on the reranking results + sorted_docs = [] + for result in results: + idx = result["index"] + score = result["relevance_score"] + doc = documents[idx] + doc.score = score + sorted_docs.append(doc) + + return {"documents": sorted_docs} + except ClientError as exception: + msg = f"Could not inference Amazon Bedrock model {self.model_name} due: {exception}" + raise AmazonBedrockInferenceError(msg) from exception + except KeyError as e: + msg = f"Unexpected response format from Amazon Bedrock: {e!s}" + raise AmazonBedrockInferenceError(msg) from e + except Exception as e: + msg = f"Error during Amazon Bedrock API call: {e!s}" + raise AmazonBedrockInferenceError(msg) from e diff --git a/integrations/amazon_bedrock/tests/test_ranker.py b/integrations/amazon_bedrock/tests/test_ranker.py new file mode 100644 index 000000000..f648ac551 --- /dev/null +++ b/integrations/amazon_bedrock/tests/test_ranker.py @@ -0,0 +1,103 @@ +from unittest.mock import MagicMock, patch + +import pytest +from haystack import Document +from haystack.utils import Secret + +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockInferenceError, +) +from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker + + +@pytest.fixture +def mock_aws_session(): + with patch("haystack_integrations.components.rankers.amazon_bedrock.ranker.get_aws_session") as mock_session: + mock_client = MagicMock() + mock_session.return_value.client.return_value = mock_client + yield mock_client + + +def test_bedrock_ranker_initialization(mock_aws_session): + ranker = BedrockRanker( + model="cohere.rerank-v3-5:0", + top_k=2, + aws_access_key_id=Secret.from_token("test_access_key"), + aws_secret_access_key=Secret.from_token("test_secret_key"), + aws_region_name=Secret.from_token("us-east-1"), + ) + assert ranker.model_name == "cohere.rerank-v3-5:0" + assert ranker.top_k == 2 + + +def test_bedrock_ranker_run(mock_aws_session): + ranker = BedrockRanker( + model="cohere.rerank-v3-5:0", + top_k=2, + aws_access_key_id=Secret.from_token("test_access_key"), + aws_secret_access_key=Secret.from_token("test_secret_key"), + aws_region_name=Secret.from_token("us-east-1"), + ) + + mock_response = { + "body": MagicMock( + read=MagicMock( + 
return_value=b'{"results": [{"index": 0, "relevance_score": 0.9},' + b' {"index": 1, "relevance_score": 0.7}]}' + ) + ) + } + mock_aws_session.invoke_model.return_value = mock_response + + docs = [Document(content="Test document 1"), Document(content="Test document 2")] + result = ranker.run(query="test query", documents=docs) + + assert len(result["documents"]) == 2 + assert result["documents"][0].score == 0.9 + assert result["documents"][1].score == 0.7 + + +@pytest.mark.integration +def test_bedrock_ranker_live_run(): + ranker = BedrockRanker( + model="cohere.rerank-v3-5:0", + top_k=2, + aws_region_name=Secret.from_token("eu-central-1"), + ) + + docs = [Document(content="Test document 1"), Document(content="Test document 2")] + result = ranker.run(query="test query", documents=docs) + assert len(result["documents"]) == 2 + assert isinstance(result["documents"][0].score, float) + + +def test_bedrock_ranker_run_inference_error(mock_aws_session): + ranker = BedrockRanker( + model="cohere.rerank-v3-5:0", + top_k=2, + aws_access_key_id=Secret.from_token("test_access_key"), + aws_secret_access_key=Secret.from_token("test_secret_key"), + aws_region_name=Secret.from_token("us-east-1"), + ) + + mock_aws_session.invoke_model.side_effect = Exception("Inference error") + + docs = [Document(content="Test document 1"), Document(content="Test document 2")] + with pytest.raises(AmazonBedrockInferenceError): + ranker.run(query="test query", documents=docs) + + +def test_bedrock_ranker_serialization(mock_aws_session): + ranker = BedrockRanker( + model="cohere.rerank-v3-5:0", + top_k=2, + ) + + serialized = ranker.to_dict() + assert serialized["init_parameters"]["model"] == "cohere.rerank-v3-5:0" + assert serialized["init_parameters"]["top_k"] == 2 + + deserialized = BedrockRanker.from_dict(serialized) + assert isinstance(deserialized, BedrockRanker) + assert deserialized.model_name == "cohere.rerank-v3-5:0" + assert deserialized.top_k == 2 From 969dc2b09745309f9e7233e95cb37a4131250561 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 23 Jan 2025 15:10:51 +0100 Subject: [PATCH 206/229] feat!: Google AI - support for Tool + general refactoring (#1316) * progress * still broken * massive refactoring * more refactoring * tests * docstrings * refinements * feedback --- integrations/google_ai/pyproject.toml | 10 +- .../generators/google_ai/chat/gemini.py | 383 +++++----- .../tests/generators/chat/test_chat_gemini.py | 673 +++++++++++------- 3 files changed, 596 insertions(+), 470 deletions(-) diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index 4da1db297..c717f1830 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "google-ai-haystack" dynamic = ["version"] -description = 'Use models like Gemini via Makersuite' +description = 'Use models like Gemini via Google AI Studio' readme = "README.md" requires-python = ">=3.9" license = "Apache-2.0" @@ -46,6 +46,7 @@ dependencies = [ "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", + "jsonschema", # needed for Tool ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -156,3 +157,10 @@ module = [ "numpy.*", ] ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = "--strict-markers" +markers = [ + "integration: integration tests", +] +log_cli = true diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py 
b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 2addaca7a..30d92fd68 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -1,25 +1,71 @@ -import json import logging from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai from google.ai.generativelanguage import Content, Part -from google.ai.generativelanguage import Tool as ToolProto from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory, Tool +from google.generativeai.types import ( + FunctionDeclaration, + GenerateContentResponse, + HarmBlockThreshold, + HarmCategory, + content_types, +) from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses import ByteStream, StreamingChunk -from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.dataclasses import StreamingChunk +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable logger = logging.getLogger(__name__) +def _convert_chatmessage_to_google_content(message: ChatMessage) -> Content: + """ + Converts a Haystack `ChatMessage` to a Google AI `Content` object. + System messages are not supported. + + :param message: The Haystack `ChatMessage` to convert. + :returns: The Google AI `Content` object. + """ + + if message.is_from(ChatRole.SYSTEM): + msg = "This function does not support system messages." + raise ValueError(msg) + + texts = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not texts and not tool_calls and not tool_call_results: + msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + raise ValueError(msg) + + if len(texts) + len(tool_call_results) > 1: + msg = "A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`." + raise ValueError(msg) + + role = "model" if message.is_from(ChatRole.ASSISTANT) else "user" + + if tool_call_results: + part = Part( + function_response=genai.protos.FunctionResponse( + name=tool_call_results[0].origin.tool_name, response={"result": tool_call_results[0].result} + ) + ) + return Content(parts=[part], role=role) + + parts = ([Part(text=texts[0])] if texts else []) + [ + Part(function_call=genai.protos.FunctionCall(name=tc.tool_name, args=tc.arguments)) for tc in tool_calls + ] + return Content(parts=parts, role=role) + + @component class GoogleAIGeminiChatGenerator: """ - Completes chats using multimodal Gemini models through Google AI Studio. + Completes chats using Gemini models through Google AI Studio. It uses the [`ChatMessage`](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage) dataclass to interact with the model. 
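For illustration, a short sketch of what the new conversion helper produces (it mirrors the unit tests added further down in this patch):

```python
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
from haystack_integrations.components.generators.google_ai.chat.gemini import (
    _convert_chatmessage_to_google_content,
)

# an assistant message carrying a tool call becomes a Content with role "model"
# and a single function_call Part
msg = ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather", arguments={"city": "Paris"})])
content = _convert_chatmessage_to_google_content(msg)
print(content.role)   # "model"
print(content.parts)  # [Part(function_call=FunctionCall(name="weather", args={"city": "Paris"}))]
```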
@@ -32,7 +78,7 @@ class GoogleAIGeminiChatGenerator: from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator - gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", api_key=Secret.from_token("")) + gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-1.5-flash", api_key=Secret.from_token("")) messages = [ChatMessage.from_user("What is the most interesting thing you know?")] res = gemini_chat.run(messages=messages) @@ -49,51 +95,40 @@ class GoogleAIGeminiChatGenerator: #### With function calling: ```python + from typing import Annotated from haystack.utils import Secret from haystack.dataclasses.chat_message import ChatMessage - from google.ai.generativelanguage import FunctionDeclaration, Tool + from haystack.components.tools import ToolInvoker + from haystack.tools import create_tool_from_function from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator - # Example function to get the current weather - def get_current_weather(location: str, unit: str = "celsius") -> str: - # Call a weather API and return some text - ... - - # Define the function interface - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - tool = Tool([get_current_weather_func]) + # example function to get the current weather + def get_current_weather( + location: Annotated[str, "The city for which to get the weather, e.g. 'San Francisco'"] = "Munich", + unit: Annotated[str, "The unit for the temperature, e.g. 'celsius'"] = "celsius", + ) -> str: + return f"The weather in {location} is sunny. The temperature is 20 {unit}." - messages = [ChatMessage.from_user("What is the most interesting thing you know?")] + tool = create_tool_from_function(get_current_weather) + tool_invoker = ToolInvoker(tools=[tool]) - gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", api_key=Secret.from_token(""), - tools=[tool]) + gemini_chat = GoogleAIGeminiChatGenerator( + model="gemini-2.0-flash-exp", + api_key=Secret.from_token(""), + tools=[tool], + ) + user_message = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] + replies = gemini_chat.run(messages=user_message)["replies"] + print(replies[0].tool_calls) - messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] - res = gemini_chat.run(messages=messages) + # actually invoke the tool + tool_messages = tool_invoker.run(messages=replies)["tool_messages"] + messages = user_message + replies + tool_messages - weather = get_current_weather(**json.loads(res["replies"][0].text)) - messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")] - res = gemini_chat.run(messages=messages) - for reply in res["replies"]: - print(reply.text) + # transform the tool call result into a human readable message + final_replies = gemini_chat.run(messages=messages)["replies"] + print(final_replies[0].text) ``` """ @@ -110,34 +145,37 @@ def __init__( """ Initializes a `GoogleAIGeminiChatGenerator` instance. - To get an API key, visit: https://makersuite.google.com + To get an API key, visit: https://aistudio.google.com/ :param api_key: Google AI Studio API key. 
To get a key, - see [Google AI Studio](https://makersuite.google.com). + see [Google AI Studio](https://aistudio.google.com/). :param model: Name of the model to use. For available models, see https://ai.google.dev/gemini-api/docs/models/gemini. :param generation_config: The generation configuration to use. This can either be a `GenerationConfig` object or a dictionary of parameters. For available parameters, see - [the `GenerationConfig` API reference](https://ai.google.dev/api/python/google/generativeai/GenerationConfig). + [the API reference](https://ai.google.dev/api/generate-content). :param safety_settings: The safety settings to use. A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. - For more information, see [the API reference](https://ai.google.dev/api) - :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). + For more information, see [the API reference](https://ai.google.dev/api/generate-content) + :param tools: + A list of tools for which the model can prepare calls. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. """ genai.configure(api_key=api_key.resolve_value()) + _check_duplicate_tool_names(tools) self._api_key = api_key self._model_name = model self._generation_config = generation_config self._safety_settings = safety_settings self._tools = tools - self._model = GenerativeModel(self._model_name, tools=self._tools) + self._model = GenerativeModel(self._model_name) self._streaming_callback = streaming_callback - def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: + @staticmethod + def _generation_config_to_dict(config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): return config return { @@ -164,18 +202,9 @@ def to_dict(self) -> Dict[str, Any]: model=self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, + tools=[tool.to_dict() for tool in self._tools] if self._tools else None, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [] - for tool in tools: - if isinstance(tool, Tool): - # There are multiple Tool types in the Google lib, one that is a protobuf class and - # another is a simple Python class. They have a similar structure but the Python class - # can't be easily serializated to a dict. We need to convert it to a protobuf class first. - tool = tool.to_proto() # noqa: PLW2901 - data["init_parameters"]["tools"].append(ToolProto.serialize(tool)) if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -193,17 +222,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": Deserialized component. """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) - - if (tools := data["init_parameters"].get("tools")) is not None: - deserialized_tools = [] - for tool in tools: - # Tools are always serialized as a protobuf class, so we need to deserialize them first - # to be able to convert them to the Python class. 
- proto = ToolProto.deserialize(tool) - deserialized_tools.append( - Tool(function_declarations=proto.function_declarations, code_execution=proto.code_execution) - ) - data["init_parameters"]["tools"] = deserialized_tools + deserialize_tools_inplace(data["init_parameters"], key="tools") if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -214,81 +233,29 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) - def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: - if isinstance(part, str): - converted_part = Part() - converted_part.text = part - return converted_part - elif isinstance(part, ByteStream): - converted_part = Part() - converted_part.inline_data.data = part.data - converted_part.inline_data.mime_type = part.mime_type - return converted_part - elif isinstance(part, Part): - return part - else: - msg = f"Unsupported type {type(part)} for part {part}" - raise ValueError(msg) - - def _message_to_part(self, message: ChatMessage) -> Part: - if message.is_from(ChatRole.ASSISTANT) and message.name: - p = Part() - p.function_call.name = message.name - p.function_call.args = {} - for k, v in json.loads(message.text).items(): - p.function_call.args[k] = v - return p - elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): - p = Part() - p.text = message.text - return p - elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): - p = Part() - p.function_response.name = message.name - p.function_response.response = message.text - return p - elif message.is_from(ChatRole.TOOL): - p = Part() - p.function_response.name = message.tool_call_result.origin.tool_name - p.function_response.response = message.tool_call_result.result - return p - elif message.is_from(ChatRole.USER): - return self._convert_part(message.text) - - def _message_to_content(self, message: ChatMessage) -> Content: - if message.is_from(ChatRole.ASSISTANT) and message.name: - part = Part() - part.function_call.name = message.name - part.function_call.args = {} - for k, v in json.loads(message.text).items(): - part.function_call.args[k] = v - elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): - part = Part() - part.text = message.text - elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): - part = Part() - part.function_response.name = message.name - part.function_response.response = message.text - elif message.is_from(ChatRole.USER): - part = self._convert_part(message.text) - elif message.is_from(ChatRole.TOOL): - part = Part() - part.function_response.name = message.tool_call_result.origin.tool_name - part.function_response.response = message.tool_call_result.result - else: - msg = f"Unsupported message role {message.role}" - raise ValueError(msg) - - role = "user" - if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM): - role = "model" - return Content(parts=[part], role=role) + @staticmethod + def _convert_to_google_tool(tool: Tool) -> FunctionDeclaration: + """ + Converts a Haystack `Tool` to a Google AI `FunctionDeclaration` object. + + :param tool: The Haystack `Tool` to convert. 
+ :returns: The Google AI `FunctionDeclaration` object. + """ + parameters = tool.parameters.copy() + + # Remove default values as Google API doesn't support them + for prop in parameters["properties"].values(): + prop.pop("default", None) + + return FunctionDeclaration(name=tool.name, description=tool.description, parameters=parameters) @component.output_types(replies=List[ChatMessage]) def run( self, messages: List[ChatMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + *, + tools: Optional[List[Tool]] = None, ): """ Generates text based on the provided messages. @@ -297,98 +264,114 @@ def run( A list of `ChatMessage` instances, representing the input messages. :param streaming_callback: A callback function that is called when a new token is received from the stream. + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. :returns: A dictionary containing the following key: - `replies`: A list containing the generated responses as `ChatMessage` instances. """ streaming_callback = streaming_callback or self._streaming_callback - history = [self._message_to_content(m) for m in messages[:-1]] - session = self._model.start_chat(history=history) - new_message = self._message_to_part(messages[-1]) + tools = tools or self._tools + _check_duplicate_tool_names(tools) + google_tools = [self._convert_to_google_tool(tool) for tool in tools] if tools else None + + if messages[0].is_from(ChatRole.SYSTEM): + self._model._system_instruction = content_types.to_content(messages[0].text) + messages = messages[1:] + + google_messages = [_convert_chatmessage_to_google_content(m) for m in messages] + + session = self._model.start_chat(history=google_messages[:-1]) + res = session.send_message( - content=new_message, + content=google_messages[-1], generation_config=self._generation_config, safety_settings=self._safety_settings, stream=streaming_callback is not None, + tools=google_tools, ) - replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + replies = ( + self._stream_response_and_convert_to_messages(res, streaming_callback) + if streaming_callback + else self._convert_response_to_messages(res) + ) return {"replies": replies} - def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMessage]: + @staticmethod + def _convert_response_to_messages(response_body: GenerateContentResponse) -> List[ChatMessage]: """ - Extracts the responses from the Google AI response. + Converts the Google AI response to a list of `ChatMessage` instances. :param response_body: The response from Google AI request. - :returns: The extracted responses. + :returns: List of `ChatMessage` instances. 
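For illustration, a rough sketch of the new run-time `tools` parameter (a sketch only; the weather tool below is a made-up example built with `create_tool_from_function`, as in the class docstring above):

```python
from typing import Annotated

from haystack.dataclasses.chat_message import ChatMessage
from haystack.tools import create_tool_from_function
from haystack.utils import Secret
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator


def get_current_weather(city: Annotated[str, "The city, e.g. 'Berlin'"] = "Berlin") -> str:
    return f"The weather in {city} is sunny."


weather_tool = create_tool_from_function(get_current_weather)
generator = GoogleAIGeminiChatGenerator(
    model="gemini-2.0-flash-exp", api_key=Secret.from_token("<your-api-key>")
)

# tools passed to run() take precedence over tools passed to the constructor
reply = generator.run(
    messages=[ChatMessage.from_user("What is the weather in Berlin?")],
    tools=[weather_tool],
)["replies"][0]
if reply.tool_calls:
    print(reply.tool_calls[0].tool_name, reply.tool_calls[0].arguments)
```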
""" - replies: List[ChatMessage] = [] metadata = response_body.to_dict() - # currently Google only supports one candidate and usage metadata reflects this - # this should be refactored when multiple candidates are supported - usage_metadata_openai_format = {} + # Only one candidate is supported for chat functionality + candidate = response_body.candidates[0] + candidate_metadata = metadata["candidates"][0] + candidate_metadata.pop("content", None) # we remove content from the metadata - usage_metadata = metadata.get("usage_metadata") - if usage_metadata: - usage_metadata_openai_format = { + # adapt usage metadata to OpenAI format + if usage_metadata := metadata.get("usage_metadata"): + candidate_metadata["usage"] = { "prompt_tokens": usage_metadata["prompt_token_count"], "completion_tokens": usage_metadata["candidates_token_count"], "total_tokens": usage_metadata["total_token_count"], } - for idx, candidate in enumerate(response_body.candidates): - candidate_metadata = metadata["candidates"][idx] - candidate_metadata.pop("content", None) # we remove content from the metadata - if usage_metadata_openai_format: - candidate_metadata["usage"] = usage_metadata_openai_format - - for part in candidate.content.parts: - if part.text != "": - replies.append(ChatMessage.from_assistant(part.text, meta=candidate_metadata)) - elif part.function_call: - candidate_metadata["function_call"] = part.function_call - new_message = ChatMessage.from_assistant( - json.dumps(dict(part.function_call.args)), meta=candidate_metadata + text = "" + tool_calls = [] + + for part in candidate.content.parts: + if part.text: + text += part.text + elif part.function_call: + tool_calls.append( + ToolCall( + tool_name=part.function_call.name, + arguments=dict(part.function_call.args), ) - try: - new_message.name = part.function_call.name - except AttributeError: - new_message._name = part.function_call.name - replies.append(new_message) - return replies - - def _get_stream_response( - self, stream: GenerateContentResponse, streaming_callback: Callable[[StreamingChunk], None] + ) + + return [ChatMessage.from_assistant(text=text or None, meta=candidate_metadata, tool_calls=tool_calls)] + + @staticmethod + def _stream_response_and_convert_to_messages( + stream: GenerateContentResponse, streaming_callback: Callable[[StreamingChunk], None] ) -> List[ChatMessage]: """ - Extracts the responses from the Google AI streaming response. + Streams the Google AI response and converts it to a list of `ChatMessage` instances. :param stream: The streaming response from the Google AI request. :param streaming_callback: The handler for the streaming response. - :returns: The extracted response with the content of all streaming chunks. + :returns: List of `ChatMessage` instances. 
""" - replies: List[ChatMessage] = [] + text = "" + tool_calls = [] + last_metadata = None + for chunk in stream: - content: Union[str, Dict[str, Any]] = "" - dict_chunk = chunk.to_dict() - metadata = dict(dict_chunk) # we copy and store the whole chunk as metadata in streaming calls - for candidate in dict_chunk["candidates"]: - for part in candidate["content"]["parts"]: - if "text" in part and part["text"] != "": - content = part["text"] - replies.append(ChatMessage.from_assistant(content, meta=metadata)) - elif "function_call" in part and len(part["function_call"]) > 0: - metadata["function_call"] = part["function_call"] - content = json.dumps(dict(part["function_call"]["args"])) - new_message = ChatMessage.from_assistant(content, meta=metadata) - try: - new_message.name = part["function_call"]["name"] - except AttributeError: - new_message._name = part["function_call"]["name"] - replies.append(new_message) - - streaming_callback(StreamingChunk(content=content, meta=metadata)) - return replies + chunk_dict = chunk.to_dict() + last_metadata = chunk_dict + # Only one candidate is supported for chat functionality + candidate = chunk_dict["candidates"][0] + + for part in candidate["content"]["parts"]: + if part.get("text"): + text += part["text"] + elif part.get("function_call"): + tool_calls.append( + ToolCall( + tool_name=part["function_call"]["name"], + arguments=dict(part["function_call"]["args"]), + ) + ) + + streaming_callback(StreamingChunk(content=text, meta=chunk_dict)) + + return [ChatMessage.from_assistant(text=text or None, meta=last_metadata, tool_calls=tool_calls)] diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index ce12d4a4d..254da4506 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -1,308 +1,443 @@ import json import os +from typing import Annotated, Literal from unittest.mock import patch +import google.generativeai as genai import pytest +from google.ai.generativelanguage import Part from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool +from google.generativeai.types import HarmBlockThreshold, HarmCategory +from haystack import Pipeline from haystack.dataclasses import StreamingChunk -from haystack.dataclasses.chat_message import ChatMessage, ChatRole - -from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator - -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent, ToolCall +from haystack.tools import Tool, create_tool_from_function + +from haystack_integrations.components.generators.google_ai.chat.gemini import ( + GoogleAIGeminiChatGenerator, + _convert_chatmessage_to_google_content, ) +TYPE = "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator" + + +def get_current_weather( + city: Annotated[str, "the city for which to get the weather, e.g. 
'San Francisco'"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius", +): + """A simple function to get the current weather for a location.""" + return f"Weather report for {city}: 20 {unit}, sunny" + + +@pytest.fixture +def tools(): + tool = create_tool_from_function(get_current_weather) + return [tool] + + +def test_convert_chatmessage_to_google_content(): + chat_message = ChatMessage.from_assistant("Hello, how are you?") + google_content = _convert_chatmessage_to_google_content(chat_message) + assert google_content.parts == [Part(text="Hello, how are you?")] + assert google_content.role == "model" -def test_init(monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test") + message = ChatMessage.from_user("I have a question") + google_content = _convert_chatmessage_to_google_content(message) + assert google_content.parts == [Part(text="I have a question")] + assert google_content.role == "user" - generation_config = GenerationConfig( - candidate_count=1, - stop_sequences=["stop"], - max_output_tokens=10, - temperature=0.5, - top_p=0.5, - top_k=0.5, + message = ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] ) - safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - with patch( - "haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure" - ) as mock_genai_configure: - gemini = GoogleAIGeminiChatGenerator( - generation_config=generation_config, - safety_settings=safety_settings, - tools=[tool], - ) - mock_genai_configure.assert_called_once_with(api_key="test") - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._generation_config == generation_config - assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] - assert isinstance(gemini._model, GenerativeModel) - - -def test_to_dict(monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test") - - with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): - gemini = GoogleAIGeminiChatGenerator() - assert gemini.to_dict() == { - "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gemini-1.5-flash", - "generation_config": None, - "safety_settings": None, - "streaming_callback": None, - "tools": None, - }, - } - - -def test_to_dict_with_param(monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test") - - generation_config = GenerationConfig( - candidate_count=1, - stop_sequences=["stop"], - max_output_tokens=10, - temperature=0.5, - top_p=0.5, - top_k=2, + google_content = _convert_chatmessage_to_google_content(message) + assert google_content.parts == [ + Part(function_call=genai.protos.FunctionCall(name="weather", args={"city": "Paris"})) + ] + assert google_content.role == "model" + + tool_result = json.dumps({"weather": "sunny", "temperature": "25"}) + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) ) - safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + google_content = _convert_chatmessage_to_google_content(message) + assert 
google_content.parts == [ + Part(function_response=genai.protos.FunctionResponse(name="weather", response={"result": tool_result})) + ] + assert google_content.role == "user" - with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): - gemini = GoogleAIGeminiChatGenerator( - generation_config=generation_config, - safety_settings=safety_settings, - tools=[tool], + +def test_convert_chatmessage_to_google_content_invalid(): + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + _convert_chatmessage_to_google_content(message) + + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")], + ) + with pytest.raises(ValueError): + _convert_chatmessage_to_google_content(message) + + message = ChatMessage.from_system("You are a helpful assistant.") + with pytest.raises(ValueError): + _convert_chatmessage_to_google_content(message) + + +class TestGoogleAIGeminiChatGenerator: + def test_init(self, tools, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, ) - assert gemini.to_dict() == { - "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gemini-1.5-flash", - "generation_config": { - "temperature": 0.5, - "top_p": 0.5, - "top_k": 2, - "candidate_count": 1, - "max_output_tokens": 10, - "stop_sequences": ["stop"], + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + with patch( + "haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure" + ) as mock_genai_configure: + gemini = GoogleAIGeminiChatGenerator( + generation_config=generation_config, + safety_settings=safety_settings, + tools=tools, + ) + mock_genai_configure.assert_called_once_with(api_key="test") + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config == generation_config + assert gemini._safety_settings == safety_settings + assert gemini._tools == tools + assert isinstance(gemini._model, GenerativeModel) + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): + gemini = GoogleAIGeminiChatGenerator() + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-1.5-flash", + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, }, - "safety_settings": {10: 3}, - "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. 
San Francisco, CAB\x08location" - ], - }, - } - - -def test_from_dict(monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test") - - with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): - gemini = GoogleAIGeminiChatGenerator.from_dict( - { - "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gemini-1.5-flash", - "generation_config": None, - "safety_settings": None, - "streaming_callback": None, - "tools": None, + } + + def test_to_dict_with_param(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + tools = [Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)] + + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=2, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): + gemini = GoogleAIGeminiChatGenerator( + generation_config=generation_config, + safety_settings=safety_settings, + tools=tools, + ) + assert gemini.to_dict() == { + "type": TYPE, + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-1.5-flash", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 2, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], }, - } + "safety_settings": {10: 3}, + "streaming_callback": None, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + }, + } + ], + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): + gemini = GoogleAIGeminiChatGenerator.from_dict( + { + "type": TYPE, + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-1.5-flash", + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + ) + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config is None + assert gemini._safety_settings is None + assert gemini._tools is None + assert isinstance(gemini._model, GenerativeModel) + + def test_from_dict_with_param(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): + gemini = GoogleAIGeminiChatGenerator.from_dict( + { + "type": TYPE, + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-1.5-flash", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 2, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {10: 3}, + "streaming_callback": None, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + }, + } + ], + }, + } + ) + + 
assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config == GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=2, + ) + assert gemini._safety_settings == { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH + } + assert len(gemini._tools) == 1 + assert gemini._tools[0].name == "name" + assert gemini._tools[0].description == "description" + assert gemini._tools[0].parameters == {"x": {"type": "string"}} + assert gemini._tools[0].function == print + assert isinstance(gemini._model, GenerativeModel) + + def test_serde_in_pipeline(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test") + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = GoogleAIGeminiChatGenerator( + model="gemini-1.5-flash", + generation_config=GenerationConfig( + temperature=0.6, + stop_sequences=["stop", "words"], + ), + tools=[tool], ) - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._generation_config is None - assert gemini._safety_settings is None - assert gemini._tools is None - assert isinstance(gemini._model, GenerativeModel) - - -def test_from_dict_with_param(monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test") - - with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): - gemini = GoogleAIGeminiChatGenerator.from_dict( - { - "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gemini-1.5-flash", - "generation_config": { - "temperature": 0.5, - "top_p": 0.5, - "top_k": 2, - "candidate_count": 1, - "max_output_tokens": 10, - "stop_sequences": ["stop"], + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": TYPE, + "init_parameters": { + "api_key": {"env_vars": ["GOOGLE_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gemini-1.5-flash", + "generation_config": { + "candidate_count": None, + "max_output_tokens": None, + "temperature": 0.6, + "stop_sequences": ["stop", "words"], + "top_k": None, + "top_p": None, + }, + "safety_settings": None, + "streaming_callback": None, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "name", + "description": "description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + }, + } + ], }, - "safety_settings": {10: 3}, - "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. 
San Francisco, CAB\x08location" - ], - }, - } - ) + } + }, + "connections": [], + } - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._generation_config == GenerationConfig( - candidate_count=1, - stop_sequences=["stop"], - max_output_tokens=10, - temperature=0.5, - top_p=0.5, - top_k=2, - ) - assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert len(gemini._tools) == 1 - assert len(gemini._tools[0].function_declarations) == 1 - assert gemini._tools[0].function_declarations[0].name == "get_current_weather" - assert gemini._tools[0].function_declarations[0].description == "Get the current weather in a given location" - assert ( - gemini._tools[0].function_declarations[0].parameters.properties["location"].description - == "The city and state, e.g. San Francisco, CA" - ) - assert gemini._tools[0].function_declarations[0].parameters.properties["unit"].enum == ["celsius", "fahrenheit"] - assert gemini._tools[0].function_declarations[0].parameters.required == ["location"] - assert isinstance(gemini._model, GenerativeModel) - - -@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") -def test_run(): - # We're ignoring the unused function argument check since we must have that argument for the test - # to run successfully, but we don't actually use it. - def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 - return {"weather": "sunny", "temperature": 21.8, "unit": unit} - - get_current_weather_func = FunctionDeclaration.from_function( - get_current_weather, - descriptions={ - "location": "The city, e.g. San Francisco", - "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit", - }, - ) + pipeline_yaml = pipeline.dumps() - tool = Tool(function_declarations=[get_current_weather_func]) - gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool]) - messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] - response = gemini_chat.run(messages=messages) - assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) - - # check the first response is a function call - chat_message = response["replies"][0] - assert "function_call" in chat_message.meta - assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - - weather = get_current_weather(**json.loads(chat_message.text)) - if hasattr(ChatMessage, "from_function"): - messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] - response = gemini_chat.run(messages=messages) + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + def test_convert_to_google_tool(self, tools): + tool = tools[0] + google_tool = GoogleAIGeminiChatGenerator._convert_to_google_tool(tool) + + assert google_tool.name == tool.name + assert google_tool.description == tool.description + + assert google_tool.parameters + # check if default values are removed. 
This type is not easily inspectable + assert "default" not in str(google_tool.parameters) + + @pytest.mark.integration + @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") + def test_run(self): + gemini_chat = GoogleAIGeminiChatGenerator() + chat_messages = [ChatMessage.from_user("What's the capital of France")] + response = gemini_chat.run(messages=chat_messages) + assert "replies" in response + assert len(response["replies"]) > 0 + + reply = response["replies"][0] + assert reply.role == ChatRole.ASSISTANT + assert "paris" in reply.text.lower() + + assert not reply.tool_calls + assert not reply.tool_call_results + + assert "usage" in reply.meta + assert "prompt_tokens" in reply.meta["usage"] + assert "completion_tokens" in reply.meta["usage"] + assert "total_tokens" in reply.meta["usage"] + + @pytest.mark.integration + @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") + def test_run_with_tools(self, tools): + + gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-2.0-flash-exp", tools=tools) + user_message = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] + response = gemini_chat.run(messages=user_message) assert "replies" in response assert len(response["replies"]) > 0 assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) - # check the second response is not a function call + # check the first response contains a tool call chat_message = response["replies"][0] - assert "function_call" not in chat_message.meta - assert isinstance(chat_message.text, str) + assert chat_message.tool_calls + assert chat_message.tool_calls[0].tool_name == "get_current_weather" + assert chat_message.tool_calls[0].arguments == {"city": "Berlin", "unit": "Celsius"} + + weather = tools[0].invoke(**chat_message.tool_calls[0].arguments) + + messages = ( + user_message + + response["replies"] + + [ChatMessage.from_tool(tool_result=weather, origin=chat_message.tool_calls[0])] + ) + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) -@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") -def test_run_with_streaming_callback(): - streaming_callback_called = False + # check the second response is not a tool call + chat_message = response["replies"][0] + assert not chat_message.tool_calls + assert chat_message.text + assert "berlin" in chat_message.text.lower() - def streaming_callback(_chunk: StreamingChunk) -> None: - nonlocal streaming_callback_called - streaming_callback_called = True + @pytest.mark.integration + @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") + def test_run_with_streaming_callback_and_tools(self, tools): + streaming_callback_called = False - def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 - return {"weather": "sunny", "temperature": 21.8, "unit": unit} + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True - get_current_weather_func = FunctionDeclaration.from_function( - get_current_weather, - descriptions={ - "location": "The city, e.g. San Francisco", - "unit": "The temperature unit of measurement, e.g. 
celsius or fahrenheit", - }, - ) + gemini_chat = GoogleAIGeminiChatGenerator( + model="gemini-2.0-flash-exp", tools=tools, streaming_callback=streaming_callback + ) + messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] + response = gemini_chat.run(messages=messages) + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + assert streaming_callback_called - tool = Tool(function_declarations=[get_current_weather_func]) - gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback) - messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] - response = gemini_chat.run(messages=messages) - assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) - assert streaming_callback_called - - # check the first response is a function call - chat_message = response["replies"][0] - assert "function_call" in chat_message.meta - assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"} - - weather = get_current_weather(**json.loads(chat_message.text)) - if hasattr(ChatMessage, "from_function"): - messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")] + # check the first response contains a tool call + chat_message = response["replies"][0] + assert chat_message.tool_calls + assert chat_message.tool_calls[0].tool_name == "get_current_weather" + assert chat_message.tool_calls[0].arguments == {"city": "Berlin", "unit": "Celsius"} + + weather = tools[0].invoke(**chat_message.tool_calls[0].arguments) + messages += response["replies"] + [ + ChatMessage.from_tool(tool_result=weather, origin=chat_message.tool_calls[0]) + ] response = gemini_chat.run(messages=messages) assert "replies" in response assert len(response["replies"]) > 0 assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) - # check the second response is not a function call + # check the second response is not a tool call chat_message = response["replies"][0] - assert "function_call" not in chat_message.meta - assert isinstance(chat_message.text, str) - - -@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") -def test_past_conversation(): - gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro") - messages = [ - ChatMessage.from_system("You are a knowledageable mathematician."), - ChatMessage.from_user("What is 2+2?"), - ChatMessage.from_assistant("It's an arithmetic operation."), - ChatMessage.from_user("Yeah, but what's the result?"), - ] - response = gemini_chat.run(messages=messages) - assert "replies" in response - replies = response["replies"] - assert len(replies) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in replies) - - assert all("usage" in reply.meta for reply in replies) - assert all("prompt_tokens" in reply.meta["usage"] for reply in replies) - assert all("completion_tokens" in reply.meta["usage"] for reply in replies) - assert all("total_tokens" in reply.meta["usage"] for reply in replies) + assert not chat_message.tool_calls + assert chat_message.text + assert "berlin" in chat_message.text.lower() + + @pytest.mark.integration + @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set") + def test_past_conversation(self): + gemini_chat = 
GoogleAIGeminiChatGenerator() + messages = [ + ChatMessage.from_system("You are a knowledageable mathematician."), + ChatMessage.from_user("What is 2+2?"), + ChatMessage.from_assistant("It's an arithmetic operation."), + ChatMessage.from_user("Yeah, but what's the result?"), + ] + response = gemini_chat.run(messages=messages) + assert "replies" in response + replies = response["replies"] + assert len(replies) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in replies) + + assert all("usage" in reply.meta for reply in replies) + assert all("prompt_tokens" in reply.meta["usage"] for reply in replies) + assert all("completion_tokens" in reply.meta["usage"] for reply in replies) + assert all("total_tokens" in reply.meta["usage"] for reply in replies) From b1d537500cafd1c7f0b5d68f8aa7988a9eb6f52e Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 23 Jan 2025 14:13:22 +0000 Subject: [PATCH 207/229] Update the changelog --- integrations/google_ai/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 5004682ee..0d2f386ff 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/google_ai-v5.0.0] - 2025-01-23 + +### 🚀 Features + +- [**breaking**] Google AI - support for Tool + general refactoring (#1316) + + ## [integrations/google_ai-v4.1.0] - 2025-01-16 ### 🧹 Chores From d2a10fd7e09d6b5876626f3bc387af3c6042fc0f Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 23 Jan 2025 06:59:25 -0800 Subject: [PATCH 208/229] refactor: GoogleAIGeminiGenerator - make some attributes public (#1317) * Make variables public * Fix tests * Change back to self._model * Change to self._api_key --- .../components/generators/google_ai/gemini.py | 24 ++++++++-------- .../google_ai/tests/generators/test_gemini.py | 28 +++++++++---------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index b032169df..fd3ddd7bf 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -101,11 +101,11 @@ def __init__( genai.configure(api_key=api_key.resolve_value()) self._api_key = api_key - self._model_name = model - self._generation_config = generation_config - self._safety_settings = safety_settings - self._model = GenerativeModel(self._model_name) - self._streaming_callback = streaming_callback + self.model_name = model + self.generation_config = generation_config + self.safety_settings = safety_settings + self._model = GenerativeModel(self.model_name) + self.streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): @@ -126,13 +126,13 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. 
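For illustration, a brief sketch of what this refactor means for callers (a sketch only; it assumes `GoogleAIGeminiGenerator` is exported from the package in the same way as the chat generator):

```python
from haystack.utils import Secret
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator

gen = GoogleAIGeminiGenerator(model="gemini-1.5-flash", api_key=Secret.from_token("<your-api-key>"))

# these attributes are public after the refactor; the API key and the underlying
# GenerativeModel client remain private (self._api_key, self._model)
print(gen.model_name)          # "gemini-1.5-flash"
print(gen.generation_config)   # None unless set explicitly
print(gen.safety_settings)     # None unless set explicitly
print(gen.streaming_callback)  # None unless set explicitly
```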
""" - callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None data = default_to_dict( self, api_key=self._api_key.to_dict(), - model=self._model_name, - generation_config=self._generation_config, - safety_settings=self._safety_settings, + model=self.model_name, + generation_config=self.generation_config, + safety_settings=self.safety_settings, streaming_callback=callback_name, ) if (generation_config := data["init_parameters"].get("generation_config")) is not None: @@ -198,13 +198,13 @@ def run( """ # check if streaming_callback is passed - streaming_callback = streaming_callback or self._streaming_callback + streaming_callback = streaming_callback or self.streaming_callback converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] res = self._model.generate_content( contents=contents, - generation_config=self._generation_config, - safety_settings=self._safety_settings, + generation_config=self.generation_config, + safety_settings=self.safety_settings, stream=streaming_callback is not None, ) self._model.start_chat() diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index 07d194a59..eb09514eb 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -18,7 +18,7 @@ def test_init(monkeypatch): max_output_tokens=10, temperature=0.5, top_p=0.5, - top_k=0.5, + top_k=1, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} @@ -28,9 +28,9 @@ def test_init(monkeypatch): safety_settings=safety_settings, ) mock_genai_configure.assert_called_once_with(api_key="test") - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._generation_config == generation_config - assert gemini._safety_settings == safety_settings + assert gemini.model_name == "gemini-1.5-flash" + assert gemini.generation_config == generation_config + assert gemini.safety_settings == safety_settings assert isinstance(gemini._model, GenerativeModel) @@ -105,7 +105,7 @@ def test_from_dict_with_param(monkeypatch): "generation_config": { "temperature": 0.5, "top_p": 0.5, - "top_k": 0.5, + "top_k": 1, "candidate_count": 1, "max_output_tokens": 10, "stop_sequences": ["stop"], @@ -116,16 +116,16 @@ def test_from_dict_with_param(monkeypatch): } ) - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._generation_config == GenerationConfig( + assert gemini.model_name == "gemini-1.5-flash" + assert gemini.generation_config == GenerationConfig( candidate_count=1, stop_sequences=["stop"], max_output_tokens=10, temperature=0.5, top_p=0.5, - top_k=0.5, + top_k=1, ) - assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert gemini.safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert isinstance(gemini._model, GenerativeModel) @@ -141,7 +141,7 @@ def test_from_dict(monkeypatch): "generation_config": { "temperature": 0.5, "top_p": 0.5, - "top_k": 0.5, + "top_k": 1, "candidate_count": 1, "max_output_tokens": 10, "stop_sequences": ["stop"], @@ -152,16 +152,16 @@ def test_from_dict(monkeypatch): } ) - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._generation_config == GenerationConfig( + assert 
gemini.model_name == "gemini-1.5-flash" + assert gemini.generation_config == GenerationConfig( candidate_count=1, stop_sequences=["stop"], max_output_tokens=10, temperature=0.5, top_p=0.5, - top_k=0.5, + top_k=1, ) - assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert gemini.safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert isinstance(gemini._model, GenerativeModel) From 68ec202baa7cfa9af12bcf610498fbfeba5bc299 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 23 Jan 2025 15:00:27 +0000 Subject: [PATCH 209/229] Update the changelog --- integrations/google_ai/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 0d2f386ff..63e35e88d 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/google_ai-v5.0.1] - 2025-01-23 + +### 🚜 Refactor + +- GoogleAIGeminiGenerator - make some attributes public (#1317) + + ## [integrations/google_ai-v5.0.0] - 2025-01-23 ### 🚀 Features From 3d0dfede3263a0b5ff39d810dc1d7fd0273c78e8 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 23 Jan 2025 17:09:58 +0100 Subject: [PATCH 210/229] feat: AmazonBedrockChatGenerator - add tools support (#1304) * AmazonBedrockChatGenerator - add tools support * Remove test not needed * Add actual pipeline integration test with tools * Extract instance functions to free standing * No need to test serde on all models * Add serde test * Fix serde test * Lint * Always pack thinking + tool call into single ChatMessage * Revert accidental changes * Method renaming for Python Zen * Add class pydocs * Don't run pipeline in serde test * Update test_serde_in_pipeline test * Lint --- integrations/amazon_bedrock/pyproject.toml | 1 + .../amazon_bedrock/chat/chat_generator.py | 401 +++++++++++++----- .../tests/test_chat_generator.py | 365 ++++++++++++---- 3 files changed, 573 insertions(+), 194 deletions(-) diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 7044eb453..07f8db679 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", + "jsonschema", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index bcf11414c..d01dc5f92 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -1,12 +1,13 @@ import json import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple from botocore.config import Config from botocore.eventstream import EventStream from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall +from 
haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -19,6 +20,215 @@ logger = logging.getLogger(__name__) +def _format_tools(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]: + """ + Format Haystack Tool(s) to Amazon Bedrock toolConfig format. + + :param tools: List of Tool objects to format + :return: Dictionary in Bedrock toolConfig format or None if no tools + """ + if not tools: + return None + + tool_specs = [] + for tool in tools: + tool_specs.append( + {"toolSpec": {"name": tool.name, "description": tool.description, "inputSchema": {"json": tool.parameters}}} + ) + + return {"tools": tool_specs} if tool_specs else None + + +def _format_messages(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Format a list of ChatMessages to the format expected by Bedrock API. + Separates system messages and handles tool results and tool calls. + + :param messages: List of ChatMessages to format + :return: Tuple of (system_prompts, non_system_messages) in Bedrock format + """ + system_prompts = [] + non_system_messages = [] + + for msg in messages: + if msg.is_from(ChatRole.SYSTEM): + system_prompts.append({"text": msg.text}) + continue + + # Handle tool results - must role these as user messages + if msg.tool_call_results: + tool_results = [] + for result in msg.tool_call_results: + try: + json_result = json.loads(result.result) + content = [{"json": json_result}] + except json.JSONDecodeError: + content = [{"text": result.result}] + + tool_results.append( + { + "toolResult": { + "toolUseId": result.origin.id, + "content": content, + **({"status": "error"} if result.error else {}), + } + } + ) + non_system_messages.append({"role": "user", "content": tool_results}) + continue + + content = [] + # Handle text content + if msg.text: + content.append({"text": msg.text}) + + # Handle tool calls + if msg.tool_calls: + for tool_call in msg.tool_calls: + content.append( + {"toolUse": {"toolUseId": tool_call.id, "name": tool_call.tool_name, "input": tool_call.arguments}} + ) + + if content: # Only add message if it has content + non_system_messages.append({"role": msg.role.value, "content": content}) + + return system_prompts, non_system_messages + + +def _parse_completion_response(response_body: Dict[str, Any], model: str) -> List[ChatMessage]: + """ + Parse a Bedrock response to a list of ChatMessage objects. 
+ + :param response_body: Raw response from Bedrock API + :param model: The model ID used for generation + :return: List of ChatMessage objects + """ + replies = [] + if "output" in response_body and "message" in response_body["output"]: + message = response_body["output"]["message"] + if message["role"] == "assistant": + content_blocks = message["content"] + + # Common meta information + base_meta = { + "model": model, + "index": 0, + "finish_reason": response_body.get("stopReason"), + "usage": { + # OpenAI's format for usage for cross ChatGenerator compatibility + "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), + "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), + "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), + }, + } + + # Process all content blocks and combine them into a single message + text_content = [] + tool_calls = [] + for content_block in content_blocks: + if "text" in content_block: + text_content.append(content_block["text"]) + elif "toolUse" in content_block: + # Convert tool use to ToolCall + tool_use = content_block["toolUse"] + tool_call = ToolCall( + id=tool_use.get("toolUseId"), + tool_name=tool_use.get("name"), + arguments=tool_use.get("input", {}), + ) + tool_calls.append(tool_call) + + # Create a single ChatMessage with combined text and tool calls + replies.append(ChatMessage.from_assistant(" ".join(text_content), tool_calls=tool_calls, meta=base_meta)) + + return replies + + +def _parse_streaming_response( + response_stream: EventStream, + streaming_callback: Callable[[StreamingChunk], None], + model: str, +) -> List[ChatMessage]: + """ + Parse a streaming response from Bedrock. + + :param response_stream: EventStream from Bedrock API + :param streaming_callback: Callback for streaming chunks + :param model: The model ID used for generation + :return: List of ChatMessage objects + """ + replies = [] + current_content = "" + current_tool_call: Optional[Dict[str, Any]] = None + base_meta = { + "model": model, + "index": 0, + } + + for event in response_stream: + if "contentBlockStart" in event: + # Reset accumulators for new message + current_content = "" + current_tool_call = None + block_start = event["contentBlockStart"] + if "start" in block_start and "toolUse" in block_start["start"]: + tool_start = block_start["start"]["toolUse"] + current_tool_call = { + "id": tool_start["toolUseId"], + "name": tool_start["name"], + "arguments": "", # Will accumulate deltas as string + } + + elif "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + delta_text = delta["text"] + current_content += delta_text + streaming_chunk = StreamingChunk(content=delta_text, meta=None) + streaming_callback(streaming_chunk) + elif "toolUse" in delta and current_tool_call: + # Accumulate tool use input deltas + current_tool_call["arguments"] += delta["toolUse"].get("input", "") + + elif "contentBlockStop" in event: + if current_tool_call: + # Parse accumulated input if it's a JSON string + try: + input_json = json.loads(current_tool_call["arguments"]) + current_tool_call["arguments"] = input_json + except json.JSONDecodeError: + # Keep as string if not valid JSON + pass + + tool_call = ToolCall( + id=current_tool_call["id"], + tool_name=current_tool_call["name"], + arguments=current_tool_call["arguments"], + ) + replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy())) + elif current_content: + 
replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy())) + + elif "messageStop" in event: + # Update finish reason for all replies + for reply in replies: + reply.meta["finish_reason"] = event["messageStop"].get("stopReason") + + elif "metadata" in event: + metadata = event["metadata"] + # Update usage stats for all replies + for reply in replies: + if "usage" in metadata: + usage = metadata["usage"] + reply.meta["usage"] = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + + return replies + + @component class AmazonBedrockChatGenerator: """ @@ -41,6 +251,64 @@ class AmazonBedrockChatGenerator: client = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", streaming_callback=print_streaming_chunk) client.run(messages, generation_kwargs={"max_tokens": 512}) + ``` + + ### Tool usage example + # AmazonBedrockChatGenerator supports Haystack's unified tool architecture, allowing tools to be used + # across different chat generators. The same tool definitions and usage patterns work consistently + # whether using Amazon Bedrock, OpenAI, Ollama, or any other supported LLM providers. + + ```python + from haystack.dataclasses import ChatMessage + from haystack.tools import Tool + from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator + + def weather(city: str): + return f'The weather in {city} is sunny and 32°C' + + # Define tool parameters + tool_parameters = { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + + # Create weather tool + weather_tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=weather + ) + + # Initialize generator with tool + client = AmazonBedrockChatGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + tools=[weather_tool] + ) + + # Run initial query + messages = [ChatMessage.from_user("What's the weather like in Paris?")] + results = client.run(messages=messages) + + # Get tool call from response + tool_message = next(msg for msg in results["replies"] if msg.tool_call) + tool_call = tool_message.tool_call + + # Execute tool and send result back + weather_result = weather(**tool_call.arguments) + new_messages = [ + messages[0], + tool_message, + ChatMessage.from_tool(tool_result=weather_result, origin=tool_call) + ] + + # Get final response + final_result = client.run(new_messages) + print(final_result["replies"][0].text) + + > Based on the information I've received, I can tell you that the weather in Paris is + > currently sunny with a temperature of 32°C (which is about 90°F). ``` @@ -70,6 +338,7 @@ def __init__( stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, boto3_config: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, ): """ Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the @@ -103,6 +372,7 @@ def __init__( [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches the streaming mode on. :param boto3_config: The configuration for the boto3 client. + :param tools: A list of Tool objects that the model can use. Each tool should have a unique name. :raises ValueError: If the model name is empty or None. 
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is @@ -120,6 +390,8 @@ def __init__( self.stop_words = stop_words or [] self.streaming_callback = streaming_callback self.boto3_config = boto3_config + _check_duplicate_tool_names(tools) + self.tools = tools def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -155,6 +427,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, @@ -167,6 +440,7 @@ def to_dict(self) -> Dict[str, Any]: generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, boto3_config=self.boto3_config, + tools=serialized_tools, ) @classmethod @@ -186,6 +460,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": data["init_parameters"], ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) + deserialize_tools_inplace(data["init_parameters"], key="tools") return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) @@ -194,6 +469,7 @@ def run( messages: List[ChatMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, ): generation_kwargs = generation_kwargs or {} @@ -209,20 +485,19 @@ def run( if key in merged_kwargs } - # Extract tool configuration if present - # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + # Handle tools - either toolConfig or Haystack Tool objects but not both + tools = tools or self.tools + _check_duplicate_tool_names(tools) tool_config = merged_kwargs.pop("toolConfig", None) + if tools: + # Format Haystack tools to Bedrock format + tool_config = _format_tools(tools) # Any remaining kwargs go to additionalModelRequestFields additional_fields = merged_kwargs if merged_kwargs else None - # Prepare system prompts and messages - system_prompts = [] - if messages and messages[0].is_from(ChatRole.SYSTEM): - system_prompts = [{"text": messages[0].text}] - messages = messages[1:] - - messages_list = [{"role": msg.role.value, "content": [{"text": msg.text}]} for msg in messages] + # Format messages to Bedrock format + system_prompts, messages_list = _format_messages(messages) # Build API parameters params = { @@ -245,112 +520,12 @@ def run( if not response_stream: msg = "No stream found in the response." 
raise AmazonBedrockInferenceError(msg) - replies = self.process_streaming_response(response_stream, callback) + replies = _parse_streaming_response(response_stream, callback, self.model) else: response = self.client.converse(**params) - replies = self.extract_replies_from_response(response) + replies = _parse_completion_response(response, self.model) except ClientError as exception: msg = f"Could not generate inference for Amazon Bedrock model {self.model} due: {exception}" raise AmazonBedrockInferenceError(msg) from exception return {"replies": replies} - - def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - replies = [] - if "output" in response_body and "message" in response_body["output"]: - message = response_body["output"]["message"] - if message["role"] == "assistant": - content_blocks = message["content"] - - # Common meta information - base_meta = { - "model": self.model, - "index": 0, - "finish_reason": response_body.get("stopReason"), - "usage": { - # OpenAI's format for usage for cross ChatGenerator compatibility - "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), - "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), - "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), - }, - } - - # Process each content block separately - for content_block in content_blocks: - if "text" in content_block: - replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy())) - elif "toolUse" in content_block: - replies.append( - ChatMessage.from_assistant(json.dumps(content_block["toolUse"]), meta=base_meta.copy()) - ) - return replies - - def process_streaming_response( - self, response_stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] - ) -> List[ChatMessage]: - replies = [] - current_content = "" - current_tool_use = None - base_meta = { - "model": self.model, - "index": 0, - } - - for event in response_stream: - if "contentBlockStart" in event: - # Reset accumulators for new message - current_content = "" - current_tool_use = None - block_start = event["contentBlockStart"] - if "start" in block_start and "toolUse" in block_start["start"]: - tool_start = block_start["start"]["toolUse"] - current_tool_use = { - "toolUseId": tool_start["toolUseId"], - "name": tool_start["name"], - "input": "", # Will accumulate deltas as string - } - - elif "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - delta_text = delta["text"] - current_content += delta_text - streaming_chunk = StreamingChunk(content=delta_text, meta=None) - # it only makes sense to call callback on text deltas - streaming_callback(streaming_chunk) - elif "toolUse" in delta and current_tool_use: - # Accumulate tool use input deltas - current_tool_use["input"] += delta["toolUse"].get("input", "") - elif "contentBlockStop" in event: - if current_tool_use: - # Parse accumulated input if it's a JSON string - try: - input_json = json.loads(current_tool_use["input"]) - current_tool_use["input"] = input_json - except json.JSONDecodeError: - # Keep as string if not valid JSON - pass - - tool_content = json.dumps(current_tool_use) - replies.append(ChatMessage.from_assistant(tool_content, meta=base_meta.copy())) - elif current_content: - replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy())) - - elif "messageStop" in event: - # not 100% correct for multiple messages but no way around it - for reply in replies: - 
reply.meta["finish_reason"] = event["messageStop"].get("stopReason") - - elif "metadata" in event: - metadata = event["metadata"] - # not 100% correct for multiple messages but no way around it - for reply in replies: - if "usage" in metadata: - usage = metadata["usage"] - reply.meta["usage"] = { - "prompt_tokens": usage.get("inputTokens", 0), - "completion_tokens": usage.get("outputTokens", 0), - "total_tokens": usage.get("totalTokens", 0), - } - - return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index c2122163c..d5d19ea64 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,11 +1,17 @@ -import json from typing import Any, Dict, Optional import pytest +from haystack import Pipeline from haystack.components.generators.utils import print_streaming_chunk +from haystack.components.tools import ToolInvoker from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.tools import Tool from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator +from haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator import ( + _parse_completion_response, + _parse_streaming_response, +) KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" MODELS_TO_TEST = [ @@ -23,6 +29,11 @@ STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"] +def weather(city: str): + """Get weather for a given city.""" + return f"The weather in {city} is sunny and 32°C" + + @pytest.fixture def chat_messages(): messages = [ @@ -32,6 +43,18 @@ def chat_messages(): return messages +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=weather, + ) + return [tool] + + @pytest.mark.parametrize( "boto3_config", [ @@ -64,6 +87,7 @@ def test_to_dict(mock_boto3_session, boto3_config): "stop_words": [], "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "boto3_config": boto3_config, + "tools": None, }, } @@ -96,6 +120,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "boto3_config": boto3_config, + "tools": None, }, } ) @@ -206,7 +231,9 @@ def streaming_callback(chunk: StreamingChunk): @pytest.mark.integration def test_tools_use(self, model_name): """ - Test function calling with AWS Bedrock Anthropic adapter + Test tools use with passing the generation_kwargs={"toolConfig": tool_config} + and not the tools parameter. We support this because some users might want to use the toolConfig + parameter to pass the tool configuration to the model. 
""" # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html tool_config = { @@ -244,43 +271,33 @@ def test_tools_use(self, model_name): assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.text, "First reply has no content" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert first_reply.meta, "First reply has no metadata" - - # Some models return thinking message as first and the second one as the tool call - if len(replies) > 1: - second_reply = replies[1] - assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.text, "Second reply has no content" - assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" - else: - # case where the model returns the tool call as the first message - # double check that the tool call is correct - tool_call = json.loads(first_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" + # Find the message with tool calls as in some models it is the first message, in some second + tool_message = None + for message in replies: + if message.tool_call: # Using tool_call instead of tool_calls to match existing code + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "top_song", f"{tool_call} does not contain the correct 'tool_name' value" + assert tool_call.arguments, f"Tool call {tool_call} does not contain 'arguments' value" + assert ( + tool_call.arguments["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'arguments' value" @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) @pytest.mark.integration def test_tools_use_with_streaming(self, model_name): """ - Test function calling with AWS Bedrock Anthropic adapter + Test tools use with streaming but with passing the generation_kwargs={"toolConfig": tool_config} + and not the tools parameter. We support this because some users might want to use the toolConfig + parameter to pass the tool configuration to the model. 
""" - # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html tool_config = { "tools": [ { @@ -304,12 +321,10 @@ def test_tools_use_with_streaming(self, model_name): } } ], - # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html "toolChoice": {"auto": {}}, } - messages = [] - messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + messages = [ChatMessage.from_user("What is the most popular song on WZPZ?")] client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=print_streaming_chunk) response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) replies = response["replies"] @@ -322,36 +337,28 @@ def test_tools_use_with_streaming(self, model_name): assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" - # Some models return thinking message as first and the second one as the tool call - if len(replies) > 1: - second_reply = replies[1] - assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.text, "Second reply has no content" - assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" - else: - # case where the model returns the tool call as the first message - # double check that the tool call is correct - tool_call = json.loads(first_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" + # Find the message containing the tool call + tool_message = None + for message in replies: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "top_song", f"{tool_call} does not contain the correct 'tool_name' value" + assert tool_call.arguments, f"{tool_call} does not contain 'arguments' value" + assert tool_call.arguments["sign"] == "WZPZ", f"{tool_call} does not contain the correct 'input' value" def test_extract_replies_from_response(self, mock_boto3_session): """ Test that extract_replies_from_response correctly processes both text and tool use responses """ - generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") - + model = "anthropic.claude-3-5-sonnet-20240620-v1:0" # Test case 1: Simple text response text_response = { 
"output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}}, @@ -359,11 +366,11 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, } - replies = generator.extract_replies_from_response(text_response) + replies = _parse_completion_response(text_response, model) assert len(replies) == 1 assert replies[0].text == "This is a test response" assert replies[0].role == ChatRole.ASSISTANT - assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["model"] == model assert replies[0].meta["finish_reason"] == "complete" assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} @@ -379,12 +386,12 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40}, } - replies = generator.extract_replies_from_response(tool_response) + replies = _parse_completion_response(tool_response, model) assert len(replies) == 1 - tool_content = json.loads(replies[0].text) - assert tool_content["toolUseId"] == "123" - assert tool_content["name"] == "test_tool" - assert tool_content["input"] == {"key": "value"} + tool_content = replies[0].tool_call + assert tool_content.id == "123" + assert tool_content.tool_name == "test_tool" + assert tool_content.arguments == {"key": "value"} assert replies[0].meta["finish_reason"] == "tool_call" assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40} @@ -403,19 +410,19 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60}, } - replies = generator.extract_replies_from_response(mixed_response) - assert len(replies) == 2 + replies = _parse_completion_response(mixed_response, model) + assert len(replies) == 1 assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer." - tool_content = json.loads(replies[1].text) - assert tool_content["toolUseId"] == "456" - assert tool_content["name"] == "search_tool" - assert tool_content["input"] == {"query": "test"} + tool_content = replies[0].tool_call + assert tool_content.id == "456" + assert tool_content.tool_name == "search_tool" + assert tool_content.arguments == {"query": "test"} def test_process_streaming_response(self, mock_boto3_session): """ Test that process_streaming_response correctly handles streaming events and accumulates responses """ - generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + model = "anthropic.claude-3-5-sonnet-20240620-v1:0" streaming_chunks = [] @@ -436,7 +443,7 @@ def test_callback(chunk: StreamingChunk): {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}}, ] - replies = generator.process_streaming_response(events, test_callback) + replies = _parse_streaming_response(events, test_callback, model) # Verify streaming chunks were received for text content assert len(streaming_chunks) == 2 @@ -447,12 +454,208 @@ def test_callback(chunk: StreamingChunk): assert len(replies) == 2 # Check text reply assert replies[0].text == "Let me help you." 
- assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["model"] == model assert replies[0].meta["finish_reason"] == "complete" assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} # Check tool use reply - tool_content = json.loads(replies[1].text) - assert tool_content["toolUseId"] == "123" - assert tool_content["name"] == "search_tool" - assert tool_content["input"] == {"query": "test"} + tool_content = replies[1].tool_call + assert tool_content.id == "123" + assert tool_content.tool_name == "search_tool" + assert tool_content.arguments == {"query": "test"} + + @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) + @pytest.mark.integration + def test_live_run_with_tools(self, model_name, tools): + """ + Integration test that the AmazonBedrockChatGenerator component can run with tools. Here we are using the + Haystack tools parameter to pass the tool configuration to the model. + """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AmazonBedrockChatGenerator(model=model_name, tools=tools) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_use" + + new_messages = [ + initial_messages[0], + tool_message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + # Pass the tool result to the model to get the final response + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) + @pytest.mark.integration + def test_live_run_with_tools_streaming(self, model_name, tools): + """ + Integration test that the AmazonBedrockChatGenerator component can run with the Haystack tools parameter. + and the streaming_callback parameter to get the streaming response. 
+ """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AmazonBedrockChatGenerator(model=model_name, tools=tools, streaming_callback=print_streaming_chunk) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_use" + + new_messages = [ + initial_messages[0], + tool_message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + # Pass the tool result to the model to get the final response + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + @pytest.mark.parametrize("model_name", [MODELS_TO_TEST_WITH_TOOLS[0]]) # just one model is enough + @pytest.mark.integration + def test_pipeline_with_amazon_bedrock_chat_generator(self, model_name, tools): + """ + Test that the AmazonBedrockChatGenerator component can be used in a pipeline + """ + + pipeline = Pipeline() + pipeline.add_component("generator", AmazonBedrockChatGenerator(model=model_name, tools=tools)) + pipeline.add_component("tool_invoker", ToolInvoker(tools=tools)) + + pipeline.connect("generator", "tool_invoker") + + results = pipeline.run( + data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + ) + + assert ( + "The weather in Paris is sunny and 32°C" + == results["tool_invoker"]["tool_messages"][0].tool_call_result.result + ) + + def test_serde_in_pipeline(self, mock_boto3_session, monkeypatch): + """ + Test serialization/deserialization of AmazonBedrockChatGenerator in a Pipeline, + including YAML conversion and detailed dictionary validation + """ + # Set mock AWS credentials + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret") + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + + # Create a test tool + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={"city": {"type": "string"}}, + function=weather, + ) + + # Create generator with specific configuration + generator = AmazonBedrockChatGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + generation_kwargs={"temperature": 0.7}, + stop_words=["eviscerate"], + streaming_callback=print_streaming_chunk, + tools=[tool], + ) + + # Create and configure pipeline + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + # Get pipeline dictionary and verify its structure + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": KLASS, + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": 
["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "generation_kwargs": {"temperature": 0.7}, + "stop_words": ["eviscerate"], + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "boto3_config": None, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "weather", + "description": "useful to determine the weather in a given location", + "parameters": {"city": {"type": "string"}}, + "function": "tests.test_chat_generator.weather", + }, + } + ], + }, + } + }, + "connections": [], + } + + # Test YAML serialization/deserialization + pipeline_yaml = pipeline.dumps() + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + # Verify the loaded pipeline's generator has the same configuration + loaded_generator = new_pipeline.get_component("generator") + assert loaded_generator.model == generator.model + assert loaded_generator.generation_kwargs == generator.generation_kwargs + assert loaded_generator.streaming_callback == generator.streaming_callback + assert len(loaded_generator.tools) == len(generator.tools) + assert loaded_generator.tools[0].name == generator.tools[0].name + assert loaded_generator.tools[0].description == generator.tools[0].description + assert loaded_generator.tools[0].parameters == generator.tools[0].parameters From 8b9d0e9199eab67ef6c3f50b37ae685af2342e76 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 23 Jan 2025 17:13:53 +0100 Subject: [PATCH 211/229] refactor!: AmazonBedrockGenerator - remove truncation (#1314) * AmazonBedrockGenerator - remove truncation * Update tests * Linting * Remove deprecation test * Update docs config * Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py Co-authored-by: Stefano Fiorucci * Lint --------- Co-authored-by: Stefano Fiorucci --- integrations/amazon_bedrock/pydoc/config.yml | 2 +- integrations/amazon_bedrock/pyproject.toml | 3 +- .../generators/amazon_bedrock/generator.py | 57 ++------ .../generators/amazon_bedrock/handlers.py | 62 --------- integrations/amazon_bedrock/tests/conftest.py | 8 -- .../amazon_bedrock/tests/test_generator.py | 126 +----------------- 6 files changed, 16 insertions(+), 242 deletions(-) delete mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py diff --git a/integrations/amazon_bedrock/pydoc/config.yml b/integrations/amazon_bedrock/pydoc/config.yml index e6d6eba78..3f70c9c75 100644 --- a/integrations/amazon_bedrock/pydoc/config.yml +++ b/integrations/amazon_bedrock/pydoc/config.yml @@ -7,8 +7,8 @@ loaders: "haystack_integrations.components.embedders.amazon_bedrock.text_embedder", "haystack_integrations.components.generators.amazon_bedrock.generator", "haystack_integrations.components.generators.amazon_bedrock.adapters", + "haystack_integrations.common.amazon_bedrock.errors", "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator", - "haystack_integrations.components.generators.amazon_bedrock.handlers", 
"haystack_integrations.components.rankers.amazon_bedrock.ranker", ] ignore_when_discovered: ["__init__"] diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 07f8db679..3e0a1c1b5 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "boto3>=1.28.57", "transformers!=4.48.*"] +dependencies = ["haystack-ai", "boto3>=1.28.57"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme" @@ -156,7 +156,6 @@ exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ "botocore.*", - "transformers.*", "boto3.*", "haystack.*", "haystack_integrations.*", diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 79dc07cdc..d87c3aba1 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -1,6 +1,7 @@ import json import logging import re +import warnings from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args from botocore.config import Config @@ -25,9 +26,6 @@ MetaLlamaAdapter, MistralAdapter, ) -from .handlers import ( - DefaultPromptHandler, -) logger = logging.getLogger(__name__) @@ -105,8 +103,8 @@ def __init__( aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008 aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008 - max_length: Optional[int] = 100, - truncate: Optional[bool] = True, + max_length: Optional[int] = None, + truncate: Optional[bool] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, boto3_config: Optional[Dict[str, Any]] = None, model_family: Optional[MODEL_FAMILIES] = None, @@ -121,8 +119,8 @@ def __init__( :param aws_session_token: The AWS session token. :param aws_region_name: The AWS region name. Make sure the region you set supports Amazon Bedrock. :param aws_profile_name: The AWS profile name. - :param max_length: The maximum length of the generated text. - :param truncate: Whether to truncate the prompt or not. + :param max_length: Deprecated. This parameter no longer has any effect. + :param truncate: Deprecated. This parameter no longer has any effect. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. :param boto3_config: The configuration for the boto3 client. @@ -140,6 +138,14 @@ def __init__( self.model = model self.max_length = max_length self.truncate = truncate + + if max_length is not None or truncate is not None: + warnings.warn( + "The 'max_length' and 'truncate' parameters have been removed and no longer have any effect. 
" + "No truncation will be performed.", + stacklevel=2, + ) + self.aws_access_key_id = aws_access_key_id self.aws_secret_access_key = aws_secret_access_key self.aws_session_token = aws_session_token @@ -173,44 +179,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: raise AmazonBedrockConfigurationError(msg) from exception model_input_kwargs = kwargs - # We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed - model_max_length = kwargs.get("model_max_length", 4096) - - # we initialize the prompt handler only if truncate is True: we avoid unnecessarily downloading the tokenizer - if self.truncate: - # Truncate prompt if prompt tokens > model_max_length-max_length - # (max_length is the length of the generated text) - # we use GPT2 tokenizer which will likely provide good token count approximation - self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", - model_max_length=model_max_length, - max_length=self.max_length or 100, - ) model_adapter_cls = self.get_model_adapter(model=model, model_family=model_family) self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length) - def _ensure_token_limit(self, prompt: str) -> str: - """ - Ensures that the prompt and answer token lengths together are within the model_max_length specified during - the initialization of the component. - - :param prompt: The prompt to be sent to the model. - :returns: The resized prompt. - """ - resize_info = self.prompt_handler(prompt) - if resize_info["prompt_length"] != resize_info["new_prompt_length"]: - logger.warning( - "The prompt was truncated from %s tokens to %s tokens so that the prompt length and " - "the answer length (%s tokens) fit within the model's max token limit (%s tokens). " - "Shorten the prompt or it will be cut off.", - resize_info["prompt_length"], - max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore - resize_info["max_length"], - resize_info["model_max_length"], - ) - return str(resize_info["resized_prompt"]) - @component.output_types(replies=List[str]) def run( self, @@ -235,9 +207,6 @@ def run( streaming_callback = streaming_callback or self.streaming_callback generation_kwargs["stream"] = streaming_callback is not None - if self.truncate: - prompt = self._ensure_token_limit(prompt) - body = self.model_adapter.prepare_body(prompt=prompt, **generation_kwargs) try: if streaming_callback: diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py deleted file mode 100644 index 07db2742f..000000000 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Dict, Union - -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast - - -class DefaultPromptHandler: - """ - DefaultPromptHandler resizes the prompt to ensure that the prompt and answer token lengths together - are within the model_max_length. - """ - - def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100): - """ - :param tokenizer: The tokenizer to be used to tokenize the prompt. - :param model_max_length: The maximum length of the prompt and answer tokens combined. - :param max_length: The maximum length of the answer tokens. 
- :raises ValueError: If the tokenizer is not a string or a `PreTrainedTokenizer` or `PreTrainedTokenizerFast` - instance. - """ - if isinstance(tokenizer, str): - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) - elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): - self.tokenizer = tokenizer - else: - msg = "model must be a string or a PreTrainedTokenizer instance" - raise ValueError(msg) - - self.tokenizer.model_max_length = model_max_length - self.model_max_length = model_max_length - self.max_length = max_length - - def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]: - """ - Resizes the prompt to ensure that the prompt and answer is within the model_max_length - - :param prompt: the prompt to be sent to the model. - :param kwargs: Additional keyword arguments passed to the handler. - :returns: A dictionary containing the resized prompt and additional information. - """ - resized_prompt = prompt - prompt_length = 0 - new_prompt_length = 0 - - if prompt: - tokenized_prompt = self.tokenizer.tokenize(prompt) - prompt_length = len(tokenized_prompt) - if (prompt_length + self.max_length) <= self.model_max_length: - resized_prompt = prompt - new_prompt_length = prompt_length - else: - resized_prompt = self.tokenizer.convert_tokens_to_string( - tokenized_prompt[: self.model_max_length - self.max_length] - ) - new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length]) - - return { - "resized_prompt": resized_prompt, - "prompt_length": prompt_length, - "new_prompt_length": new_prompt_length, - "model_max_length": self.model_max_length, - "max_length": self.max_length, - } diff --git a/integrations/amazon_bedrock/tests/conftest.py b/integrations/amazon_bedrock/tests/conftest.py index 9406559bf..e744ea623 100644 --- a/integrations/amazon_bedrock/tests/conftest.py +++ b/integrations/amazon_bedrock/tests/conftest.py @@ -17,11 +17,3 @@ def set_env_variables(monkeypatch): def mock_boto3_session(): with patch("boto3.Session") as mock_client: yield mock_client - - -@pytest.fixture -def mock_prompt_handler(): - with patch( - "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" - ) as mock_prompt_handler: - yield mock_prompt_handler diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 3d2cbc01f..e168dc106 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -1,5 +1,5 @@ from typing import Any, Dict, Optional, Type -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, call import pytest from haystack.dataclasses import StreamingChunk @@ -107,9 +107,6 @@ def test_default_constructor(mock_boto3_session, set_env_variables): assert layer.max_length == 99 assert layer.model == "anthropic.claude-v2" - assert layer.prompt_handler is not None - assert layer.prompt_handler.model_max_length == 4096 - # assert mocked boto3 client called exactly once mock_boto3_session.assert_called_once() @@ -123,23 +120,6 @@ def test_default_constructor(mock_boto3_session, set_env_variables): ) -def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_handler): - """ - Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2 - """ - layer = AmazonBedrockGenerator(model="anthropic.claude-v2", prompt_handler=mock_prompt_handler) - assert layer.prompt_handler is not None 
- assert layer.prompt_handler.model_max_length == 4096 - - -def test_prompt_handler_absent_when_truncate_false(mock_boto3_session): - """ - Test that the prompt_handler is not initialized when truncate is set to False. - """ - generator = AmazonBedrockGenerator(model="anthropic.claude-v2", truncate=False) - assert not hasattr(generator, "prompt_handler") - - def test_constructor_with_model_kwargs(mock_boto3_session): """ Test that model_kwargs are correctly set in the constructor @@ -159,110 +139,6 @@ def test_constructor_with_empty_model(): AmazonBedrockGenerator(model="") -def test_short_prompt_is_not_truncated(mock_boto3_session): - """ - Test that a short prompt is not truncated - """ - # Define a short mock prompt and its tokenized version - mock_prompt_text = "I am a tokenized prompt" - mock_prompt_tokens = mock_prompt_text.split() - - # Mock the tokenizer so it returns our predefined tokens - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize.return_value = mock_prompt_tokens - - # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens - # Since our mock prompt is 5 tokens long, it doesn't exceed the - # total limit (5 prompt tokens + 3 generated tokens < 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): - layer = AmazonBedrockGenerator( - "anthropic.claude-v2", - max_length=max_length_generated_text, - model_max_length=total_model_max_length, - ) - prompt_after_resize = layer._ensure_token_limit(mock_prompt_text) - - # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it - assert prompt_after_resize == mock_prompt_text - - -def test_long_prompt_is_truncated(mock_boto3_session): - """ - Test that a long prompt is truncated - """ - # Define a long mock prompt and its tokenized version - long_prompt_text = "I am a tokenized prompt of length eight" - long_prompt_tokens = long_prompt_text.split() - - # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit - truncated_prompt_text = "I am a tokenized prompt of length" - - # Mock the tokenizer to return our predefined tokens - # convert tokens to our predefined truncated text - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize.return_value = long_prompt_tokens - mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text - - # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens - # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): - layer = AmazonBedrockGenerator( - "anthropic.claude-v2", - max_length=max_length_generated_text, - model_max_length=total_model_max_length, - ) - prompt_after_resize = layer._ensure_token_limit(long_prompt_text) - - # The prompt exceeds the limit, _ensure_token_limit truncates it - assert prompt_after_resize == truncated_prompt_text - - -def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): - """ - Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False - """ - long_prompt_text = "I am a tokenized prompt of length eight" - - # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 
tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()): - generator = AmazonBedrockGenerator( - model="anthropic.claude-v2", - max_length=max_length_generated_text, - model_max_length=total_model_max_length, - truncate=False, - ) - - # Mock the _ensure_token_limit method to track if it is called - with patch.object( - generator, "_ensure_token_limit", wraps=generator._ensure_token_limit - ) as mock_ensure_token_limit: - # Mock the model adapter to avoid actual invocation - generator.model_adapter.prepare_body = MagicMock(return_value={}) - generator.client = MagicMock() - generator.client.invoke_model = MagicMock( - return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} - ) - generator.model_adapter.get_responses = MagicMock(return_value=["response"]) - - # Invoke the generator - generator.run(prompt=long_prompt_text) - - # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called() - - # Check the prompt passed to prepare_body - generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text, stream=False) - - @pytest.mark.parametrize( "model, expected_model_adapter", [ From a77bb23a77b0414fa27dc17d005a49fbc2bea9e1 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 23 Jan 2025 16:15:24 +0000 Subject: [PATCH 212/229] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 162 ++++++++++++++++++++++- 1 file changed, 158 insertions(+), 4 deletions(-) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 7ba6b422c..de6a311f4 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [integrations/amazon_bedrock-v3.0.0] - 2025-01-23 + +### 🚀 Features + +- *(AWS Bedrock)* Add Cohere Reranker (#1291) +- AmazonBedrockChatGenerator - add tools support (#1304) + +### 🚜 Refactor + +- [**breaking**] AmazonBedrockGenerator - remove truncation (#1314) + + ## [integrations/amazon_bedrock-v2.1.3] - 2025-01-21 ### 🧹 Chores @@ -164,15 +176,23 @@ ### 🌀 Miscellaneous -- Chore: change the pydoc renderer class (#718) +- Change the pydoc renderer class (#718) - Adding support of "amazon.titan-embed-text-v2:0" (#735) +* adding support of "amazon.titan-embed-text-v2:0" + +* rectifying the format + +--------- + +Co-authored-by: Massimiliano Pippi + ## [integrations/amazon_bedrock-v0.7.1] - 2024-04-24 ### 🌀 Miscellaneous -- Chore: add license classifiers (#680) -- Fix: Fix streaming_callback serialization in AmazonBedrockChatGenerator (#685) +- Add license classifiers (#680) +- Fix streaming_callback serialization in AmazonBedrockChatGenerator (#685) ## [integrations/amazon_bedrock-v0.7.0] - 2024-04-16 @@ -189,6 +209,12 @@ - Remove references to Python 3.7 (#601) - [Bedrock] Added Amazon Bedrock examples (#635) +* add ChatGenerator example + +* add Generator, Embedders example + +* move system prompt from inference params to messages + ## [integrations/amazon_bedrock-v0.6.0] - 2024-03-11 ### 🚀 Features @@ -214,6 +240,8 @@ - Fix order of API docs (#447) +This PR will also push the docs to Readme + ### 📚 Documentation - Update category slug (#442) @@ -226,12 +254,106 @@ ### 🌀 Miscellaneous - Amazon bedrock: generate api docs (#326) + +* amazon bedrock: generate api docs + +* path upd + +* add dependency + +* Update amazon_bedrock.yml + +* add files - Adopt Secret to Amazon Bedrock (#416) -- 
Bedrock - remove `supports` method (#456) + +* initial import + +* addin Secret and fixing tests + +* cleaning + +* using staticmethod directly + +* removing ignore from B008 config and setting up in-line + +* addin Secret and fixing tests for chat component + +* addin Secret and fixing tests for the chat component + +* fixing to_dict from_dict + +* fixing to_dict from_dict to the chat component as well +- Bedrock - remove supports method (#456) - Bedrock refactoring (#455) + +* wip + +* Bedrock refactoring + +* rm wip embedder + +* bedrock - remove supports method + +* rename commons to common + +* fix pydoc config + +* more cleaning + +* lint + +* rename test module - Bedrock Text Embedder (#466) + +* wip + +* Bedrock refactoring + +* rm wip embedder + +* bedrock - remove supports method + +* rename commons to common + +* fix pydoc config + +* text embedder! + +* more cleaning + +* lint + +* rename test module - Bedrock Document Embedder (#468) +* wip + +* Bedrock refactoring + +* rm wip embedder + +* bedrock - remove supports method + +* rename commons to common + +* fix pydoc config + +* text embedder! + +* more cleaning + +* lint + +* first draft + +* bedrock document embedder + +* pydoc config + +* fix pydoc + +* add test for titan + ## [integrations/amazon_bedrock-v0.3.0] - 2024-01-30 ### 🧹 Chores @@ -246,4 +368,36 @@ - [Amazon Bedrock] Add AmazonBedrockGenerator (#153) +* draft AmazonBedrockGenerator + +* black + +* renaming and imports + +* hatch lint, errors, labeler + +* hatch + +* linter + +* add hatch-vcs for versioning + +* black + +* mypy + +* fix AmazonBedrockConfigurationError import + +* add transformers dependency + +* unit test marker + +* auto tokenizer fixture and imports + +* to_dict, from_dict, model_name renaming + +* black + +* simplify model_name check + From d6f7a73a4d2f2e720951976deaf37695a97a8380 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 23 Jan 2025 18:12:04 +0100 Subject: [PATCH 213/229] fix changelog (#1319) --- integrations/amazon_bedrock/CHANGELOG.md | 150 +---------------------- 1 file changed, 4 insertions(+), 146 deletions(-) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index de6a311f4..1a5e9970e 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -176,23 +176,15 @@ ### 🌀 Miscellaneous -- Change the pydoc renderer class (#718) +- Chore: change the pydoc renderer class (#718) - Adding support of "amazon.titan-embed-text-v2:0" (#735) -* adding support of "amazon.titan-embed-text-v2:0" - -* rectifying the format - ---------- - -Co-authored-by: Massimiliano Pippi - ## [integrations/amazon_bedrock-v0.7.1] - 2024-04-24 ### 🌀 Miscellaneous -- Add license classifiers (#680) -- Fix streaming_callback serialization in AmazonBedrockChatGenerator (#685) +- Chore: add license classifiers (#680) +- Fix: Fix streaming_callback serialization in AmazonBedrockChatGenerator (#685) ## [integrations/amazon_bedrock-v0.7.0] - 2024-04-16 @@ -209,12 +201,6 @@ Co-authored-by: Massimiliano Pippi - Remove references to Python 3.7 (#601) - [Bedrock] Added Amazon Bedrock examples (#635) -* add ChatGenerator example - -* add Generator, Embedders example - -* move system prompt from inference params to messages - ## [integrations/amazon_bedrock-v0.6.0] - 2024-03-11 ### 🚀 Features @@ -240,8 +226,6 @@ Co-authored-by: Massimiliano Pippi - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) @@ 
-254,106 +238,12 @@ This PR will also push the docs to Readme ### 🌀 Miscellaneous - Amazon bedrock: generate api docs (#326) - -* amazon bedrock: generate api docs - -* path upd - -* add dependency - -* Update amazon_bedrock.yml - -* add files - Adopt Secret to Amazon Bedrock (#416) - -* initial import - -* addin Secret and fixing tests - -* cleaning - -* using staticmethod directly - -* removing ignore from B008 config and setting up in-line - -* addin Secret and fixing tests for chat component - -* addin Secret and fixing tests for the chat component - -* fixing to_dict from_dict - -* fixing to_dict from_dict to the chat component as well -- Bedrock - remove supports method (#456) +- Bedrock - remove `supports` method (#456) - Bedrock refactoring (#455) - -* wip - -* Bedrock refactoring - -* rm wip embedder - -* bedrock - remove supports method - -* rename commons to common - -* fix pydoc config - -* more cleaning - -* lint - -* rename test module - Bedrock Text Embedder (#466) - -* wip - -* Bedrock refactoring - -* rm wip embedder - -* bedrock - remove supports method - -* rename commons to common - -* fix pydoc config - -* text embedder! - -* more cleaning - -* lint - -* rename test module - Bedrock Document Embedder (#468) -* wip - -* Bedrock refactoring - -* rm wip embedder - -* bedrock - remove supports method - -* rename commons to common - -* fix pydoc config - -* text embedder! - -* more cleaning - -* lint - -* first draft - -* bedrock document embedder - -* pydoc config - -* fix pydoc - -* add test for titan - ## [integrations/amazon_bedrock-v0.3.0] - 2024-01-30 ### 🧹 Chores @@ -368,36 +258,4 @@ This PR will also push the docs to Readme - [Amazon Bedrock] Add AmazonBedrockGenerator (#153) -* draft AmazonBedrockGenerator - -* black - -* renaming and imports - -* hatch lint, errors, labeler - -* hatch - -* linter - -* add hatch-vcs for versioning - -* black - -* mypy - -* fix AmazonBedrockConfigurationError import - -* add transformers dependency - -* unit test marker - -* auto tokenizer fixture and imports - -* to_dict, from_dict, model_name renaming - -* black - -* simplify model_name check - From 1ce24ceba0553e68b21d4cd3a8e175c91cbb1120 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Thu, 23 Jan 2025 18:36:28 +0100 Subject: [PATCH 214/229] docs: updating `amazon-bedrock` README, should also contain, Embedder and Ranker in the type (#1320) * updating amazon-bedrock readme * alphabetical order --- README.md | 62 +++++++++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 0f8b2f0ee..100e1b961 100644 --- a/README.md +++ b/README.md @@ -24,37 +24,37 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta [![License Compliance](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/CI_license_compliance.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/CI_license_compliance.yml) -| Package | Type | PyPi Package | Status | -|----------------------------------------------------------------------------------------------------------------|---------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [amazon-bedrock-haystack](integrations/amazon_bedrock/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-bedrock-haystack.svg)](https://pypi.org/project/amazon-bedrock-haystack) | [![Test / amazon_bedrock](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml) | -| [amazon-sagemaker-haystack](integrations/amazon_sagemaker/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) | [![Test / amazon_sagemaker](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml) | -| [anthropic-haystack](integrations/anthropic/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/anthropic-haystack.svg)](https://pypi.org/project/anthropic-haystack) | [![Test / anthropic](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/anthropic.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/anthropic.yml) | -| [astra-haystack](integrations/astra/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/astra-haystack.svg)](https://pypi.org/project/astra-haystack) | [![Test / astra](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml) | -| [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | -| [cohere-haystack](integrations/cohere/) | Embedder, Generator, Ranker | [![PyPI - 
Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | -| [deepeval-haystack](integrations/deepeval/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/deepeval-haystack.svg)](https://pypi.org/project/deepeval-haystack) | [![Test / deepeval](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml) | -| [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | -| [fastembed-haystack](integrations/fastembed/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | -| [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | -| [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | -| [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | -| [jina-haystack](integrations/jina/) | Connector, Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | -| [langfuse-haystack](integrations/langfuse/) | Tracer | [![PyPI - Version](https://img.shields.io/pypi/v/langfuse-haystack.svg?color=orange)](https://pypi.org/project/langfuse-haystack) | [![Test / 
langfuse](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml) | -| [llama-cpp-haystack](integrations/llama_cpp/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/llama-cpp-haystack.svg?color=orange)](https://pypi.org/project/llama-cpp-haystack) | [![Test / llama-cpp](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml) | -| [mistral-haystack](integrations/mistral/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/mistral-haystack.svg)](https://pypi.org/project/mistral-haystack) | [![Test / mistral](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml) | -| [mongodb-atlas-haystack](integrations/mongodb_atlas/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/mongodb-atlas-haystack.svg?color=orange)](https://pypi.org/project/mongodb-atlas-haystack) | [![Test / mongodb-atlas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml) | -| [nvidia-haystack](integrations/nvidia/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/nvidia-haystack.svg?color=orange)](https://pypi.org/project/nvidia-haystack) | [![Test / nvidia](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml) | -| [ollama-haystack](integrations/ollama/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/ollama-haystack) | [![Test / ollama](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml) | -| [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | -| [optimum-haystack](integrations/optimum/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/optimum-haystack.svg)](https://pypi.org/project/optimum-haystack) | [![Test / optimum](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml) | -| [pinecone-haystack](integrations/pinecone/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/pinecone-haystack.svg?color=orange)](https://pypi.org/project/pinecone-haystack) | [![Test / pinecone](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml) | -| [pgvector-haystack](integrations/pgvector/) | Document Store | [![PyPI - 
Version](https://img.shields.io/pypi/v/pgvector-haystack.svg?color=orange)](https://pypi.org/project/pgvector-haystack) | [![Test / pgvector](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pgvector.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pgvector.yml) | -| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | -| [ragas-haystack](integrations/ragas/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/ragas-haystack.svg)](https://pypi.org/project/ragas-haystack) | [![Test / ragas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml) | -| [snowflake-haystack](integrations/snowflake/) | Retriever | [![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) | [![Test / snowflake](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml) | -| [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | -| [uptrain-haystack](https://github.com/deepset-ai/haystack-core-integrations/tree/staging/integrations/uptrain) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [Staged](https://docs.haystack.deepset.ai/docs/breaking-change-policy#discontinuing-an-integration) | -| [weaviate-haystack](integrations/weaviate/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/weaviate-haystack.svg)](https://pypi.org/project/weaviate-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | +| Package | Type | PyPi Package | Status | +|----------------------------------------------------------------------------------------------------------------|-------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [amazon-bedrock-haystack](integrations/amazon_bedrock/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-bedrock-haystack.svg)](https://pypi.org/project/amazon-bedrock-haystack) | [![Test / 
amazon_bedrock](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml) | +| [amazon-sagemaker-haystack](integrations/amazon_sagemaker/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) | [![Test / amazon_sagemaker](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml) | +| [anthropic-haystack](integrations/anthropic/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/anthropic-haystack.svg)](https://pypi.org/project/anthropic-haystack) | [![Test / anthropic](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/anthropic.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/anthropic.yml) | +| [astra-haystack](integrations/astra/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/astra-haystack.svg)](https://pypi.org/project/astra-haystack) | [![Test / astra](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml) | +| [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | +| [cohere-haystack](integrations/cohere/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | +| [deepeval-haystack](integrations/deepeval/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/deepeval-haystack.svg)](https://pypi.org/project/deepeval-haystack) | [![Test / deepeval](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml) | +| [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | +| [fastembed-haystack](integrations/fastembed/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | +| [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - 
Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | +| [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | +| [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | +| [jina-haystack](integrations/jina/) | Connector, Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | +| [langfuse-haystack](integrations/langfuse/) | Tracer | [![PyPI - Version](https://img.shields.io/pypi/v/langfuse-haystack.svg?color=orange)](https://pypi.org/project/langfuse-haystack) | [![Test / langfuse](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml) | +| [llama-cpp-haystack](integrations/llama_cpp/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/llama-cpp-haystack.svg?color=orange)](https://pypi.org/project/llama-cpp-haystack) | [![Test / llama-cpp](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml) | +| [mistral-haystack](integrations/mistral/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/mistral-haystack.svg)](https://pypi.org/project/mistral-haystack) | [![Test / mistral](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml) | +| [mongodb-atlas-haystack](integrations/mongodb_atlas/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/mongodb-atlas-haystack.svg?color=orange)](https://pypi.org/project/mongodb-atlas-haystack) | [![Test / mongodb-atlas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml) | +| [nvidia-haystack](integrations/nvidia/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/nvidia-haystack.svg?color=orange)](https://pypi.org/project/nvidia-haystack) | [![Test / 
nvidia](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml) | +| [ollama-haystack](integrations/ollama/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/ollama-haystack) | [![Test / ollama](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml) | +| [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | +| [optimum-haystack](integrations/optimum/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/optimum-haystack.svg)](https://pypi.org/project/optimum-haystack) | [![Test / optimum](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml) | +| [pinecone-haystack](integrations/pinecone/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/pinecone-haystack.svg?color=orange)](https://pypi.org/project/pinecone-haystack) | [![Test / pinecone](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml) | +| [pgvector-haystack](integrations/pgvector/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/pgvector-haystack.svg?color=orange)](https://pypi.org/project/pgvector-haystack) | [![Test / pgvector](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pgvector.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pgvector.yml) | +| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | +| [ragas-haystack](integrations/ragas/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/ragas-haystack.svg)](https://pypi.org/project/ragas-haystack) | [![Test / ragas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ragas.yml) | +| [snowflake-haystack](integrations/snowflake/) | Retriever | [![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) | [![Test / snowflake](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/snowflake.yml) | +| [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - 
Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | +| [uptrain-haystack](https://github.com/deepset-ai/haystack-core-integrations/tree/staging/integrations/uptrain) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [Staged](https://docs.haystack.deepset.ai/docs/breaking-change-policy#discontinuing-an-integration) | +| [weaviate-haystack](integrations/weaviate/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/weaviate-haystack.svg)](https://pypi.org/project/weaviate-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | ## Releasing From 7550f1b95a7eadbc2c70c56fb11e367b29cb2367 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 24 Jan 2025 14:58:18 +0100 Subject: [PATCH 215/229] chore(deps): bump fossas/fossa-action from 1.4.0 to 1.5.0 (#1322) Bumps [fossas/fossa-action](https://github.com/fossas/fossa-action) from 1.4.0 to 1.5.0. - [Release notes](https://github.com/fossas/fossa-action/releases) - [Commits](https://github.com/fossas/fossa-action/compare/v1.4.0...v1.5.0) --- updated-dependencies: - dependency-name: fossas/fossa-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/CI_license_compliance.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI_license_compliance.yml b/.github/workflows/CI_license_compliance.yml index fc28706df..66a0f8958 100644 --- a/.github/workflows/CI_license_compliance.yml +++ b/.github/workflows/CI_license_compliance.yml @@ -76,7 +76,7 @@ jobs: # We keep the license inventory on FOSSA - name: Send license report to Fossa - uses: fossas/fossa-action@v1.4.0 + uses: fossas/fossa-action@v1.5.0 continue-on-error: true # not critical with: api-key: ${{ secrets.FOSSA_LICENSE_SCAN_TOKEN }} From 688301267d93a731f83a39a463c9925e0974126f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 28 Jan 2025 11:45:41 +0100 Subject: [PATCH 216/229] feat: Add custom Langfuse span handling support (#1313) * Add SpanHandler * Add integration test * Pydocs * Pass operation_name to span handler * Add SpanContext * Improve SpanContext and pydocs * Use type alias for languse stateful clients, remove forward refs * Mention SpanHandler in class pydocs --- .../connectors/langfuse/langfuse_connector.py | 36 ++- .../tracing/langfuse/__init__.py | 4 +- .../tracing/langfuse/tracer.py | 261 ++++++++++++++---- integrations/langfuse/tests/test_tracing.py | 86 ++++++ 4 files changed, 328 insertions(+), 59 deletions(-) diff --git a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py index 9418b2270..f016498d9 100644 --- 
a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py +++ b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py @@ -3,8 +3,9 @@ import httpx from haystack import component, default_from_dict, default_to_dict, logging, tracing from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils.base_serialization import deserialize_class_instance, serialize_class_instance -from haystack_integrations.tracing.langfuse import LangfuseTracer +from haystack_integrations.tracing.langfuse import LangfuseTracer, SpanHandler from langfuse import Langfuse logger = logging.getLogger(__name__) @@ -94,6 +95,24 @@ async def shutdown_event(): print(response["tracer"]["trace_url"]) ``` + For advanced use cases, you can also customize how spans are created and processed by + providing a custom SpanHandler. This allows you to add custom metrics, set warning levels, + or attach additional metadata to your Langfuse traces: + + ```python + from haystack_integrations.tracing.langfuse import DefaultSpanHandler, LangfuseSpan + from typing import Optional + + class CustomSpanHandler(DefaultSpanHandler): + + def handle(self, span: LangfuseSpan, component_type: Optional[str]) -> None: + # Custom span handling logic, customize Langfuse spans however it fits you + # see DefaultSpanHandler for how we create and process spans by default + pass + + connector = LangfuseConnector(span_handler=CustomSpanHandler()) + ``` + """ def __init__( @@ -103,6 +122,7 @@ def __init__( public_key: Optional[Secret] = Secret.from_env_var("LANGFUSE_PUBLIC_KEY"), # noqa: B008 secret_key: Optional[Secret] = Secret.from_env_var("LANGFUSE_SECRET_KEY"), # noqa: B008 httpx_client: Optional[httpx.Client] = None, + span_handler: Optional[SpanHandler] = None, ): """ Initialize the LangfuseConnector component. @@ -117,11 +137,16 @@ def __init__( :param httpx_client: Optional custom httpx.Client instance to use for Langfuse API calls. Note that when deserializing a pipeline from YAML, any custom client is discarded and Langfuse will create its own default client, since HTTPX clients cannot be serialized. + :param span_handler: Optional custom handler for processing spans. If None, uses DefaultSpanHandler. + The span handler controls how spans are created and processed, allowing customization of span types + based on component types and additional processing after spans are yielded. See SpanHandler class for + details on implementing custom handlers. """ self.name = name self.public = public self.secret_key = secret_key self.public_key = public_key + self.span_handler = span_handler self.tracer = LangfuseTracer( tracer=Langfuse( secret_key=secret_key.resolve_value() if secret_key else None, @@ -130,6 +155,7 @@ def __init__( ), name=name, public=public, + span_handler=span_handler, ) tracing.enable_tracing(self.tracer) @@ -158,6 +184,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: The serialized component as a dictionary. 
""" + span_handler = serialize_class_instance(self.span_handler) if self.span_handler else None return default_to_dict( self, name=self.name, @@ -165,6 +192,7 @@ def to_dict(self) -> Dict[str, Any]: secret_key=self.secret_key.to_dict() if self.secret_key else None, public_key=self.public_key.to_dict() if self.public_key else None, # Note: httpx_client is not serialized as it's not serializable + span_handler=span_handler, ) @classmethod @@ -175,5 +203,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "LangfuseConnector": :param data: The dictionary representation of this component. :returns: The deserialized component instance. """ - deserialize_secrets_inplace(data["init_parameters"], keys=["secret_key", "public_key"]) + init_params = data["init_parameters"] + deserialize_secrets_inplace(init_params, keys=["secret_key", "public_key"]) + init_params["span_handler"] = ( + deserialize_class_instance(init_params["span_handler"]) if init_params["span_handler"] else None + ) return default_from_dict(cls, data) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/__init__.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/__init__.py index e7331852d..6d8ea250f 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/__init__.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from .tracer import LangfuseTracer +from .tracer import DefaultSpanHandler, LangfuseSpan, LangfuseTracer, SpanContext, SpanHandler -__all__ = ["LangfuseTracer"] +__all__ = ["DefaultSpanHandler", "LangfuseSpan", "LangfuseTracer", "SpanContext", "SpanHandler"] diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index 50fab2c8a..530a29862 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -1,16 +1,24 @@ import contextlib import os +from abc import ABC, abstractmethod from contextvars import ContextVar +from dataclasses import dataclass from datetime import datetime from typing import Any, Dict, Iterator, List, Optional, Union -from haystack import logging +from haystack import default_from_dict, default_to_dict, logging from haystack.dataclasses import ChatMessage from haystack.tracing import Span, Tracer from haystack.tracing import tracer as proxy_tracer from haystack.tracing import utils as tracing_utils +from typing_extensions import TypeAlias import langfuse +from langfuse.client import StatefulGenerationClient, StatefulSpanClient, StatefulTraceClient + +# Type alias for Langfuse stateful clients +LangfuseStatefulClient: TypeAlias = Union[StatefulTraceClient, StatefulSpanClient, StatefulGenerationClient] + logger = logging.getLogger(__name__) @@ -50,7 +58,7 @@ class LangfuseSpan(Span): Internal class representing a bridge between the Haystack span tracing API and Langfuse. """ - def __init__(self, span: "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]") -> None: + def __init__(self, span: LangfuseStatefulClient) -> None: """ Initialize a LangfuseSpan instance. 
@@ -98,7 +106,7 @@ def set_content_tag(self, key: str, value: Any) -> None: self._data[key] = value - def raw_span(self) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]": + def raw_span(self) -> LangfuseStatefulClient: """ Return the underlying span instance. @@ -110,12 +118,185 @@ def get_correlation_data_for_logs(self) -> Dict[str, Any]: return {} +@dataclass(frozen=True) +class SpanContext: + """ + Context for creating spans in Langfuse. + + Encapsulates the information needed to create and configure a span in Langfuse tracing. + Used by SpanHandler to determine the span type (trace, generation, or default) and its configuration. + + :param name: The name of the span to create. For components, this is typically the component name. + :param operation_name: The operation being traced (e.g. "haystack.pipeline.run"). Used to determine + if a new trace should be created without warning. + :param component_type: The type of component creating the span (e.g. "OpenAIChatGenerator"). + Can be used to determine the type of span to create. + :param tags: Additional metadata to attach to the span. Contains component input/output data + and other trace information. + :param parent_span: The parent span if this is a child span. If None, a new trace will be created. + :param trace_name: The name to use for the trace when creating a parent span. Defaults to "Haystack". + :param public: Whether traces should be publicly accessible. Defaults to False. + """ + + name: str + operation_name: str + component_type: Optional[str] + tags: Dict[str, Any] + parent_span: Optional[Span] + trace_name: str = "Haystack" + public: bool = False + + def __post_init__(self) -> None: + """ + Validate the span context attributes. + + :raises ValueError: If name, operation_name or trace_name are empty + :raises TypeError: If tags is not a dictionary + """ + if not self.name: + msg = "Span name cannot be empty" + raise ValueError(msg) + if not self.operation_name: + msg = "Operation name cannot be empty" + raise ValueError(msg) + if not self.trace_name: + msg = "Trace name cannot be empty" + raise ValueError(msg) + + +class SpanHandler(ABC): + """ + Abstract base class for customizing how Langfuse spans are created and processed. + + This class defines two key extension points: + 1. create_span: Controls what type of span to create (default or generation) + 2. handle: Processes the span after component execution (adding metadata, metrics, etc.) + + To implement a custom handler: + - Extend this class or DefaultSpanHandler + - Override create_span and handle methods. It is more common to override handle. + - Pass your handler to LangfuseConnector init method + """ + + def __init__(self): + self.tracer: Optional[langfuse.Langfuse] = None + + def init_tracer(self, tracer: langfuse.Langfuse) -> None: + """ + Initialize with Langfuse tracer. Called internally by LangfuseTracer. + + :param tracer: The Langfuse client instance to use for creating spans + """ + self.tracer = tracer + + @abstractmethod + def create_span(self, context: SpanContext) -> LangfuseSpan: + """ + Create a span of appropriate type based on the context. 
+ + This method determines what kind of span to create: + - A new trace if there's no parent span + - A generation span for LLM components + - A default span for other components + + :param context: The context containing all information needed to create the span + :returns: A new LangfuseSpan instance configured according to the context + """ + pass + + @abstractmethod + def handle(self, span: LangfuseSpan, component_type: Optional[str]) -> None: + """ + Process a span after component execution by attaching metadata and metrics. + + This method is called after the component yields its span, allowing you to: + - Extract and attach token usage statistics + - Add model information + - Record timing data (e.g., time-to-first-token) + - Set log levels for quality monitoring + - Add custom metrics and observations + + :param span: The span that was yielded by the component + :param component_type: The type of component that created the span, used to determine + what metadata to extract and how to process it + """ + pass + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SpanHandler": + return default_from_dict(cls, data) + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict(self) + + +class DefaultSpanHandler(SpanHandler): + """DefaultSpanHandler provides the default Langfuse tracing behavior for Haystack.""" + + def create_span(self, context: SpanContext) -> LangfuseSpan: + message = "Tracer is not initialized" + if self.tracer is None: + raise RuntimeError(message) + tracing_ctx = tracing_context_var.get({}) + if not context.parent_span: + if context.operation_name != _PIPELINE_RUN_KEY: + logger.warning( + "Creating a new trace without a parent span is not recommended for operation '{operation_name}'.", + operation_name=context.operation_name, + ) + # Create a new trace when there's no parent span + return LangfuseSpan( + self.tracer.trace( + name=context.trace_name, + public=context.public, + id=tracing_ctx.get("trace_id"), + user_id=tracing_ctx.get("user_id"), + session_id=tracing_ctx.get("session_id"), + tags=tracing_ctx.get("tags"), + version=tracing_ctx.get("version"), + ) + ) + elif context.component_type in _ALL_SUPPORTED_GENERATORS: + return LangfuseSpan(context.parent_span.raw_span().generation(name=context.name)) + else: + return LangfuseSpan(context.parent_span.raw_span().span(name=context.name)) + + def handle(self, span: LangfuseSpan, component_type: Optional[str]) -> None: + if component_type in _SUPPORTED_GENERATORS: + meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta") + if meta: + m = meta[0] + span._span.update(usage=m.get("usage") or None, model=m.get("model")) + elif component_type in _SUPPORTED_CHAT_GENERATORS: + replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies") + if replies: + meta = replies[0].meta + completion_start_time = meta.get("completion_start_time") + if completion_start_time: + try: + completion_start_time = datetime.fromisoformat(completion_start_time) + except ValueError: + logger.error(f"Failed to parse completion_start_time: {completion_start_time}") + completion_start_time = None + span._span.update( + usage=meta.get("usage") or None, + model=meta.get("model"), + completion_start_time=completion_start_time, + ) + + class LangfuseTracer(Tracer): """ Internal class representing a bridge between the Haystack tracer and Langfuse. 
""" - def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: bool = False) -> None: + def __init__( + self, + tracer: langfuse.Langfuse, + name: str = "Haystack", + public: bool = False, + span_handler: Optional[SpanHandler] = None, + ) -> None: """ Initialize a LangfuseTracer instance. @@ -123,8 +304,9 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: :param name: The name of the pipeline or component. This name will be used to identify the tracing run on the Langfuse dashboard. :param public: Whether the tracing data should be public or private. If set to `True`, the tracing data will - be publicly accessible to anyone with the tracing URL. If set to `False`, the tracing data will be private - and only accessible to the Langfuse account owner. + be publicly accessible to anyone with the tracing URL. If set to `False`, the tracing data will be private + and only accessible to the Langfuse account owner. + :param span_handler: Custom handler for processing spans. If None, uses DefaultSpanHandler. """ if not proxy_tracer.is_content_tracing_enabled: logger.warning( @@ -137,6 +319,8 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: self._name = name self._public = public self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true" + self._span_handler = span_handler or DefaultSpanHandler() + self._span_handler.init_tracer(tracer) @contextlib.contextmanager def trace( @@ -144,68 +328,35 @@ def trace( ) -> Iterator[Span]: tags = tags or {} span_name = tags.get(_COMPONENT_NAME_KEY, operation_name) - - # Create new span depending whether there's a parent span or not - if not parent_span: - if operation_name != _PIPELINE_RUN_KEY: - logger.warning( - "Creating a new trace without a parent span is not recommended for operation '{operation_name}'.", - operation_name=operation_name, - ) - # Create a new trace if no parent span is provided - context = tracing_context_var.get({}) - span = LangfuseSpan( - self._tracer.trace( - name=self._name, - public=self._public, - id=context.get("trace_id"), - user_id=context.get("user_id"), - session_id=context.get("session_id"), - tags=context.get("tags"), - version=context.get("version"), - ) + component_type = tags.get(_COMPONENT_TYPE_KEY) + + # Create span using the handler + span = self._span_handler.create_span( + SpanContext( + name=span_name, + operation_name=operation_name, + component_type=component_type, + tags=tags, + parent_span=parent_span, + trace_name=self._name, + public=self._public, ) - elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS: - span = LangfuseSpan(parent_span.raw_span().generation(name=span_name)) - else: - span = LangfuseSpan(parent_span.raw_span().span(name=span_name)) + ) self._context.append(span) span.set_tags(tags) yield span - # Update span metadata based on component type - if tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_GENERATORS: - # Haystack returns one meta dict for each message, but the 'usage' value - # is always the same, let's just pick the first item - meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta") - if meta: - m = meta[0] - span._span.update(usage=m.get("usage") or None, model=m.get("model")) - elif tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_CHAT_GENERATORS: - replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies") - if replies: - meta = replies[0].meta - completion_start_time = meta.get("completion_start_time") - if completion_start_time: - try: - 
completion_start_time = datetime.fromisoformat(completion_start_time) - except ValueError: - logger.error(f"Failed to parse completion_start_time: {completion_start_time}") - completion_start_time = None - span._span.update( - usage=meta.get("usage") or None, - model=meta.get("model"), - completion_start_time=completion_start_time, - ) + # Let the span handler process the span + self._span_handler.handle(span, component_type) raw_span = span.raw_span() # In this section, we finalize both regular spans and generation spans created using the LangfuseSpan class. # It's important to end() these spans to ensure they are properly closed and all relevant data is recorded. # Note that we do not call end() on the main trace span itself, as its lifecycle is managed differently. - if isinstance(raw_span, (langfuse.client.StatefulSpanClient, langfuse.client.StatefulGenerationClient)): + if isinstance(raw_span, (StatefulSpanClient, StatefulGenerationClient)): raw_span.end() self._context.pop() diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 06f65e72c..77819f11d 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -1,6 +1,7 @@ import os import time from urllib.parse import urlparse +from typing import Optional import pytest import requests @@ -15,6 +16,8 @@ from haystack_integrations.components.connectors.langfuse import LangfuseConnector from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator from haystack_integrations.components.generators.cohere import CohereChatGenerator +from haystack_integrations.tracing.langfuse import LangfuseSpan, DefaultSpanHandler +from haystack_integrations.tracing.langfuse.tracer import _COMPONENT_OUTPUT_KEY # don't remove (or move) this env var setting from here, it's needed to turn tracing on os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" @@ -169,3 +172,86 @@ def test_pipeline_serialization(monkeypatch): # Verify pipeline is the same assert new_pipe == pipe + + +class QualityCheckSpanHandler(DefaultSpanHandler): + """Extends default handler to add quality checks with warning levels.""" + + def handle(self, span: LangfuseSpan, component_type: Optional[str]) -> None: + # First do the default handling (model, usage, etc.) 
+ super().handle(span, component_type) + + # Then add our custom quality checks + if component_type == "OpenAIChatGenerator": + output = span._data.get(_COMPONENT_OUTPUT_KEY, {}) + replies = output.get("replies", []) + + if not replies: + span._span.update(level="ERROR", status_message="No response received") + return + + reply = replies[0] + if "error" in reply.meta: + span._span.update(level="ERROR", status_message=f"OpenAI error: {reply.meta['error']}") + elif len(reply.text) > 10: + span._span.update(level="WARNING", status_message="Response too long (> 10 chars)") + else: + span._span.update(level="DEFAULT", status_message="Success") + + +@pytest.mark.integration +def test_custom_span_handler(): + """Test that custom span handler properly sets Langfuse levels.""" + if not all( + [os.environ.get("LANGFUSE_SECRET_KEY"), os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get("OPENAI_API_KEY")] + ): + pytest.skip("Missing required environment variables") + + pipe = Pipeline() + pipe.add_component( + "tracer", + LangfuseConnector( + name="Quality Check Example", + public=True, + span_handler=QualityCheckSpanHandler(), + ), + ) + pipe.add_component("prompt_builder", ChatPromptBuilder()) + pipe.add_component("llm", OpenAIChatGenerator()) + pipe.connect("prompt_builder.prompt", "llm.messages") + + # Test short response + messages = [ + ChatMessage.from_system("Respond with exactly 3 words."), + ChatMessage.from_user("What is Berlin?"), + ] + + response = pipe.run( + data={ + "prompt_builder": {"template_variables": {}, "template": messages}, + "tracer": {"invocation_context": {"user_id": "test_user"}}, + } + ) + + trace_url = response["tracer"]["trace_url"] + uuid = os.path.basename(urlparse(trace_url).path) + url = f"https://cloud.langfuse.com/api/public/traces/{uuid}" + + # Poll the Langfuse API + attempts = 5 + delay = 1 + while attempts >= 0: + res = requests.get( + url, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) + ) + if attempts > 0 and res.status_code != 200: + attempts -= 1 + time.sleep(delay) + delay *= 2 + continue + + assert res.status_code == 200 + content = str(res.content) + assert "WARNING" in content + assert "Response too long" in content + break From 471e33ab0b7f3138e71f88cf27cf5adc0a163f34 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 28 Jan 2025 10:48:09 +0000 Subject: [PATCH 217/229] Update the changelog --- integrations/langfuse/CHANGELOG.md | 44 ++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 414e44e41..f2e230871 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/langfuse-v0.8.0] - 2025-01-28 + +### 🚀 Features + +- Add custom Langfuse span handling support (#1313) + + ## [integrations/langfuse-v0.7.0] - 2025-01-21 ### 🚀 Features @@ -22,7 +29,7 @@ ### 🌀 Miscellaneous -- Chore: Langfuse - pin `haystack-ai>=2.9.0` and simplify message conversion (#1292) +- Pin haystack-ai>=2.9.0 and simplify (#1292) ## [integrations/langfuse-v0.6.2] - 2025-01-02 @@ -36,7 +43,7 @@ ### 🌀 Miscellaneous -- Chore: Fix tracing_context_var lint errors (#1220) +- Fix tracing_context_var lint errors (#1220) - Fix messages conversion to OpenAI format (#1272) ## [integrations/langfuse-v0.6.0] - 2024-11-18 @@ -53,6 +60,25 @@ - Fixed TypeError in LangfuseTrace (#1184) +* Added parent_span functionality in trace method + +* solved PR comments + +* 
Readded "end()" for solving Latency issues + +* chore: fix ruff linting + +* Handle multiple runs + +* Fix indentation and span closing + +* Fix tests + +--------- + +Co-authored-by: Vladimir Blagojevic +Co-authored-by: Silvano Cerza + ## [integrations/langfuse-v0.5.0] - 2024-10-01 ### 🧹 Chores @@ -61,7 +87,7 @@ ### 🌀 Miscellaneous -- Fix: Add delay to flush the Langfuse traces (#1091) +- Add delay to flush the Langfuse traces (#1091) - Add invocation_context to identify traces (#1089) ## [integrations/langfuse-v0.4.0] - 2024-09-17 @@ -93,15 +119,15 @@ ### 🌀 Miscellaneous -- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) -- Chore: Update Langfuse README to avoid common initialization issues (#952) -- Chore: langfuse - ruff update, don't ruff tests (#992) +- Install pytest-rerunfailures; change test-cov script (#845) +- Update Langfuse README to avoid common initialization issues (#952) +- Ruff update, don't ruff tests (#992) ## [integrations/langfuse-v0.2.0] - 2024-06-18 ### 🌀 Miscellaneous -- Feat: add support for Azure generators (#815) +- Add support for Azure generators (#815) ## [integrations/langfuse-v0.1.0] - 2024-06-13 @@ -119,7 +145,7 @@ ### 🌀 Miscellaneous -- Chore: change the pydoc renderer class (#718) -- Docs: add missing api references (#728) +- Change the pydoc renderer class (#718) +- Missing api references (#728) From ee1eb91f4ff27a05ef2e80a45c4d55a77d65c579 Mon Sep 17 00:00:00 2001 From: Siddharth Sahu <112792547+sahusiddharth@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:53:26 +0530 Subject: [PATCH 218/229] feat: Update Ragas integration (#1312) * baseline ragas evaluator component * added test cases & made it robust * ready for review * fixed all the linting errors * fixed the failing test cases * moditied a testcase * improved error handling * modited the _get_expected_type_description function * Address review feedback --- .../evaluation_from_pipeline_example.py | 140 ++++++ .../evaluation_with_components_example.py | 117 +++++ integrations/ragas/example/example.py | 52 -- integrations/ragas/pydoc/config.yml | 5 +- integrations/ragas/pyproject.toml | 3 +- .../components/evaluators/ragas/__init__.py | 3 +- .../components/evaluators/ragas/evaluator.py | 362 +++++++++----- .../components/evaluators/ragas/metrics.py | 333 ------------- integrations/ragas/tests/test_evaluator.py | 456 +++++------------- integrations/ragas/tests/test_metrics.py | 11 - 10 files changed, 623 insertions(+), 859 deletions(-) create mode 100644 integrations/ragas/example/evaluation_from_pipeline_example.py create mode 100644 integrations/ragas/example/evaluation_with_components_example.py delete mode 100644 integrations/ragas/example/example.py delete mode 100644 integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py delete mode 100644 integrations/ragas/tests/test_metrics.py diff --git a/integrations/ragas/example/evaluation_from_pipeline_example.py b/integrations/ragas/example/evaluation_from_pipeline_example.py new file mode 100644 index 000000000..5cb76b57f --- /dev/null +++ b/integrations/ragas/example/evaluation_from_pipeline_example.py @@ -0,0 +1,140 @@ +# A valid OpenAI API key must be provided as an environment variable "OPENAI_API_KEY" to run this example. 
+ +import os +from getpass import getpass + +if "OPENAI_API_KEY" not in os.environ: + os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:") + +from haystack import Document, Pipeline +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack.components.embedders import OpenAITextEmbedder, OpenAIDocumentEmbedder +from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever +from haystack.components.builders import ChatPromptBuilder +from haystack.dataclasses import ChatMessage +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.builders import AnswerBuilder + +from langchain_openai import ChatOpenAI +from ragas.llms import LangchainLLMWrapper +from ragas.metrics import AnswerRelevancy, ContextPrecision, Faithfulness +from ragas import evaluate +from ragas.dataset_schema import EvaluationDataset + + +dataset = [ + "OpenAI is one of the most recognized names in the large language model space, known for its GPT series of models. These models excel at generating human-like text and performing tasks like creative writing, answering questions, and summarizing content. GPT-4, their latest release, has set benchmarks in understanding context and delivering detailed responses.", + "Anthropic is well-known for its Claude series of language models, designed with a strong focus on safety and ethical AI behavior. Claude is particularly praised for its ability to follow complex instructions and generate text that aligns closely with user intent.", + "DeepMind, a division of Google, is recognized for its cutting-edge Gemini models, which are integrated into various Google products like Bard and Workspace tools. These models are renowned for their conversational abilities and their capacity to handle complex, multi-turn dialogues.", + "Meta AI is best known for its LLaMA (Large Language Model Meta AI) series, which has been made open-source for researchers and developers. LLaMA models are praised for their ability to support innovation and experimentation due to their accessibility and strong performance.", + "Meta AI with it's LLaMA models aims to democratize AI development by making high-quality models available for free, fostering collaboration across industries. Their open-source approach has been a game-changer for researchers without access to expensive resources.", + "Microsoft’s Azure AI platform is famous for integrating OpenAI’s GPT models, enabling businesses to use these advanced models in a scalable and secure cloud environment. Azure AI powers applications like Copilot in Office 365, helping users draft emails, generate summaries, and more.", + "Amazon’s Bedrock platform is recognized for providing access to various language models, including its own models and third-party ones like Anthropic’s Claude and AI21’s Jurassic. Bedrock is especially valued for its flexibility, allowing users to choose models based on their specific needs.", + "Cohere is well-known for its language models tailored for business use, excelling in tasks like search, summarization, and customer support. Their models are recognized for being efficient, cost-effective, and easy to integrate into workflows.", + "AI21 Labs is famous for its Jurassic series of language models, which are highly versatile and capable of handling tasks like content creation and code generation. 
The Jurassic models stand out for their natural language understanding and ability to generate detailed and coherent responses.", + "In the rapidly advancing field of artificial intelligence, several companies have made significant contributions with their large language models. Notable players include OpenAI, known for its GPT Series (including GPT-4); Anthropic, which offers the Claude Series; Google DeepMind with its Gemini Models; Meta AI, recognized for its LLaMA Series; Microsoft Azure AI, which integrates OpenAI’s GPT Models; Amazon AWS (Bedrock), providing access to various models including Claude (Anthropic) and Jurassic (AI21 Labs); Cohere, which offers its own models tailored for business use; and AI21 Labs, known for its Jurassic Series. These companies are shaping the landscape of AI by providing powerful models with diverse capabilities.", +] + +# Initialize components for RAG pipeline +document_store = InMemoryDocumentStore() +docs = [Document(content=doc) for doc in dataset] + +document_embedder = OpenAIDocumentEmbedder(model="text-embedding-3-small") +text_embedder = OpenAITextEmbedder(model="text-embedding-3-small") + +docs_with_embeddings = document_embedder.run(docs) +document_store.write_documents(docs_with_embeddings["documents"]) + +retriever = InMemoryEmbeddingRetriever(document_store, top_k=2) + +template = [ + ChatMessage.from_user( + """ +Given the following information, answer the question. + +Context: +{% for document in documents %} + {{ document.content }} +{% endfor %} + +Question: {{question}} +Answer: +""" + ) +] + +prompt_builder = ChatPromptBuilder(template=template) +chat_generator = OpenAIChatGenerator(model="gpt-4o-mini") + +# Creating the Pipeline +rag_pipeline = Pipeline() + +# Adding the components +rag_pipeline.add_component("text_embedder", text_embedder) +rag_pipeline.add_component("retriever", retriever) +rag_pipeline.add_component("prompt_builder", prompt_builder) +rag_pipeline.add_component("llm", chat_generator) +rag_pipeline.add_component("answer_builder", AnswerBuilder()) + +# Connecting the components +rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") +rag_pipeline.connect("retriever", "prompt_builder") +rag_pipeline.connect("prompt_builder.prompt", "llm.messages") +rag_pipeline.connect("llm.replies", "answer_builder.replies") +rag_pipeline.connect("retriever", "answer_builder.documents") +rag_pipeline.connect("llm.replies", "answer_builder.replies") +rag_pipeline.connect("retriever", "answer_builder.documents") + + +questions = [ + "Who are the major players in the large language model space?", + "What is Microsoft’s Azure AI platform known for?", + "What kind of models does Cohere provide?", +] + +references = [ + "The major players include OpenAI (GPT Series), Anthropic (Claude Series), Google DeepMind (Gemini Models), Meta AI (LLaMA Series), Microsoft Azure AI (integrating GPT Models), Amazon AWS (Bedrock with Claude and Jurassic), Cohere (business-focused models), and AI21 Labs (Jurassic Series).", + "Microsoft’s Azure AI platform is known for integrating OpenAI’s GPT models, enabling businesses to use these models in a scalable and secure cloud environment.", + "Cohere provides language models tailored for business use, excelling in tasks like search, summarization, and customer support.", +] + + +evals_list = [] + +for que_idx in range(len(questions)): + + single_turn = {} + single_turn['user_input'] = questions[que_idx] + single_turn['reference'] = references[que_idx] + + # Running the pipeline + 
response = rag_pipeline.run( + { + "text_embedder": {"text": questions[que_idx]}, + "prompt_builder": {"question": questions[que_idx]}, + "answer_builder": {"query": questions[que_idx]}, + } + ) + + single_turn['response'] = response["answer_builder"]["answers"][0].data + + haystack_documents = response["answer_builder"]["answers"][0].documents + # extracting context from haystack documents + single_turn['retrieved_contexts'] = [doc.content for doc in haystack_documents] + + evals_list.append(single_turn) + +evaluation_dataset = EvaluationDataset.from_list(evals_list) + +llm = ChatOpenAI(model="gpt-4o-mini") +evaluator_llm = LangchainLLMWrapper(llm) + +result = evaluate( + dataset=evaluation_dataset, + metrics=[AnswerRelevancy(), ContextPrecision(), Faithfulness()], + llm=evaluator_llm, +) + +print(result) + +result.to_pandas() diff --git a/integrations/ragas/example/evaluation_with_components_example.py b/integrations/ragas/example/evaluation_with_components_example.py new file mode 100644 index 000000000..c01f8a9ac --- /dev/null +++ b/integrations/ragas/example/evaluation_with_components_example.py @@ -0,0 +1,117 @@ +# A valid OpenAI API key must be provided as an environment variable "OPENAI_API_KEY" to run this example. + +import os +from getpass import getpass + +if "OPENAI_API_KEY" not in os.environ: + os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:") + + +from haystack import Document, Pipeline +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack.components.embedders import OpenAITextEmbedder, OpenAIDocumentEmbedder +from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever +from haystack.components.builders import ChatPromptBuilder +from haystack.dataclasses import ChatMessage +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.builders import AnswerBuilder +from haystack_integrations.components.evaluators.ragas import RagasEvaluator + +from langchain_openai import ChatOpenAI +from ragas.llms import LangchainLLMWrapper +from ragas.metrics import AnswerRelevancy, ContextPrecision, Faithfulness + + +dataset = [ + "OpenAI is one of the most recognized names in the large language model space, known for its GPT series of models. These models excel at generating human-like text and performing tasks like creative writing, answering questions, and summarizing content. GPT-4, their latest release, has set benchmarks in understanding context and delivering detailed responses.", + "Anthropic is well-known for its Claude series of language models, designed with a strong focus on safety and ethical AI behavior. Claude is particularly praised for its ability to follow complex instructions and generate text that aligns closely with user intent.", + "DeepMind, a division of Google, is recognized for its cutting-edge Gemini models, which are integrated into various Google products like Bard and Workspace tools. These models are renowned for their conversational abilities and their capacity to handle complex, multi-turn dialogues.", + "Meta AI is best known for its LLaMA (Large Language Model Meta AI) series, which has been made open-source for researchers and developers. LLaMA models are praised for their ability to support innovation and experimentation due to their accessibility and strong performance.", + "Meta AI with it's LLaMA models aims to democratize AI development by making high-quality models available for free, fostering collaboration across industries. 
Their open-source approach has been a game-changer for researchers without access to expensive resources.", + "Microsoft’s Azure AI platform is famous for integrating OpenAI’s GPT models, enabling businesses to use these advanced models in a scalable and secure cloud environment. Azure AI powers applications like Copilot in Office 365, helping users draft emails, generate summaries, and more.", + "Amazon’s Bedrock platform is recognized for providing access to various language models, including its own models and third-party ones like Anthropic’s Claude and AI21’s Jurassic. Bedrock is especially valued for its flexibility, allowing users to choose models based on their specific needs.", + "Cohere is well-known for its language models tailored for business use, excelling in tasks like search, summarization, and customer support. Their models are recognized for being efficient, cost-effective, and easy to integrate into workflows.", + "AI21 Labs is famous for its Jurassic series of language models, which are highly versatile and capable of handling tasks like content creation and code generation. The Jurassic models stand out for their natural language understanding and ability to generate detailed and coherent responses.", + "In the rapidly advancing field of artificial intelligence, several companies have made significant contributions with their large language models. Notable players include OpenAI, known for its GPT Series (including GPT-4); Anthropic, which offers the Claude Series; Google DeepMind with its Gemini Models; Meta AI, recognized for its LLaMA Series; Microsoft Azure AI, which integrates OpenAI’s GPT Models; Amazon AWS (Bedrock), providing access to various models including Claude (Anthropic) and Jurassic (AI21 Labs); Cohere, which offers its own models tailored for business use; and AI21 Labs, known for its Jurassic Series. These companies are shaping the landscape of AI by providing powerful models with diverse capabilities.", +] + +# Initialize components for RAG pipeline +document_store = InMemoryDocumentStore() +docs = [Document(content=doc) for doc in dataset] + +document_embedder = OpenAIDocumentEmbedder(model="text-embedding-3-small") +text_embedder = OpenAITextEmbedder(model="text-embedding-3-small") + +docs_with_embeddings = document_embedder.run(docs) +document_store.write_documents(docs_with_embeddings["documents"]) + +retriever = InMemoryEmbeddingRetriever(document_store, top_k=2) + +template = [ + ChatMessage.from_user( + """ +Given the following information, answer the question. 
+ +Context: +{% for document in documents %} + {{ document.content }} +{% endfor %} + +Question: {{question}} +Answer: +""" + ) +] + +prompt_builder = ChatPromptBuilder(template=template) +chat_generator = OpenAIChatGenerator(model="gpt-4o-mini") + +# Setting the RagasEvaluator Component +llm = ChatOpenAI(model="gpt-4o-mini") +evaluator_llm = LangchainLLMWrapper(llm) + +ragas_evaluator = RagasEvaluator( + ragas_metrics=[AnswerRelevancy(), ContextPrecision(), Faithfulness()], evaluator_llm=evaluator_llm +) + +# Creating the Pipeline +rag_pipeline = Pipeline() + +# Adding the components +rag_pipeline.add_component("text_embedder", text_embedder) +rag_pipeline.add_component("retriever", retriever) +rag_pipeline.add_component("prompt_builder", prompt_builder) +rag_pipeline.add_component("llm", chat_generator) +rag_pipeline.add_component("answer_builder", AnswerBuilder()) +rag_pipeline.add_component("ragas_evaluator", ragas_evaluator) + +# Connecting the components +rag_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") +rag_pipeline.connect("retriever", "prompt_builder") +rag_pipeline.connect("prompt_builder.prompt", "llm.messages") +rag_pipeline.connect("llm.replies", "answer_builder.replies") +rag_pipeline.connect("retriever", "answer_builder.documents") +rag_pipeline.connect("llm.replies", "answer_builder.replies") +rag_pipeline.connect("retriever", "answer_builder.documents") +rag_pipeline.connect("retriever", "ragas_evaluator.documents") +rag_pipeline.connect("llm.replies", "ragas_evaluator.response") + +# Run the pipeline +question = "What makes Meta AI’s LLaMA models stand out?" + +reference = "Meta AI’s LLaMA models stand out for being open-source, supporting innovation and experimentation due to their accessibility and strong performance." + + +result = rag_pipeline.run( + { + "text_embedder": {"text": question}, + "prompt_builder": {"question": question}, + "answer_builder": {"query": question}, + "ragas_evaluator": {"query": question, "reference": reference}, + # Each metric expects a specific set of parameters as input. Refer to the + # Ragas class' documentation for more details. + } +) + +print(result['answer_builder']['answers'][0].data, '\n') +print(result['ragas_evaluator']['result']) diff --git a/integrations/ragas/example/example.py b/integrations/ragas/example/example.py deleted file mode 100644 index ba75bdc7e..000000000 --- a/integrations/ragas/example/example.py +++ /dev/null @@ -1,52 +0,0 @@ -# A valid OpenAI API key must be provided as an environment variable "OPENAI_API_KEY" to run this example. - -from haystack import Pipeline - -from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric - -QUESTIONS = [ - "Which is the most popular global sport?", - "Who created the Python language?", -] -CONTEXTS = [ - [ - "The popularity of sports can be measured in various ways, including TV viewership, social media presence, number of participants, and economic impact. Football is undoubtedly the world's most popular sport with major events like the FIFA World Cup and sports personalities like Ronaldo and Messi, drawing a followership of more than 4 billion people." - ], - [ - "Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming language. Its design philosophy emphasizes code readability, and its language constructs aim to help programmers write clear, logical code for both small and large-scale software projects." 
- ], -] -RESPONSES = [ - "Football is the most popular sport with around 4 billion followers worldwide", - "Python language was created by Guido van Rossum.", -] - -GROUND_TRUTHS = [ - "Football is the most popular sport", - "Python language was created by Guido van Rossum.", -] - -pipeline = Pipeline() -evaluator_context = RagasEvaluator( - metric=RagasMetric.CONTEXT_PRECISION, -) -evaluator_aspect = RagasEvaluator( - metric=RagasMetric.ASPECT_CRITIQUE, - metric_params={"name": "custom", "definition": "Is this answer problematic for children?", "strictness": 3}, -) -pipeline.add_component("evaluator_context", evaluator_context) -pipeline.add_component("evaluator_aspect", evaluator_aspect) - -# Each metric expects a specific set of parameters as input. Refer to the -# Ragas class' documentation for more details. -results = pipeline.run( - { - "evaluator_context": {"questions": QUESTIONS, "contexts": CONTEXTS, "ground_truths": GROUND_TRUTHS}, - "evaluator_aspect": {"questions": QUESTIONS, "contexts": CONTEXTS, "responses": RESPONSES}, - } -) - - -for component in ["evaluator_context", "evaluator_aspect"]: - for output in results[component]["results"]: - print(output) diff --git a/integrations/ragas/pydoc/config.yml b/integrations/ragas/pydoc/config.yml index 97d8d808e..033bc738c 100644 --- a/integrations/ragas/pydoc/config.yml +++ b/integrations/ragas/pydoc/config.yml @@ -3,8 +3,7 @@ loaders: search_path: [../src] modules: [ - "haystack_integrations.components.evaluators.ragas.evaluator", - "haystack_integrations.components.evaluators.ragas.metrics", + "haystack_integrations.components.evaluators.ragas.evaluator" ] ignore_when_discovered: ["__init__"] processors: @@ -13,8 +12,6 @@ processors: documented_only: true do_not_filter_modules: false skip_empty_modules: true - - type: filter - expression: "name not in ['InputConverters', 'MetricDescriptor', 'MetricParamsValidators', 'OutputConverters', 'METRIC_DESCRIPTORS']" - type: smart - type: crossref renderer: diff --git a/integrations/ragas/pyproject.toml b/integrations/ragas/pyproject.toml index 179bcce16..548545e34 100644 --- a/integrations/ragas/pyproject.toml +++ b/integrations/ragas/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "ragas>=0.1.11,<=0.1.16"] +dependencies = ["haystack-ai", "ragas>=0.2.0,<0.3.0", "langchain_openai"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ragas" @@ -48,6 +48,7 @@ dependencies = [ "pytest-rerunfailures", "haystack-pydoc-tools", "pytest-asyncio", + "pydantic" ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/__init__.py b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/__init__.py index a6f420701..f572e367f 100644 --- a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/__init__.py +++ b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/__init__.py @@ -1,4 +1,3 @@ from .evaluator import RagasEvaluator -from .metrics import RagasMetric -__all__ = ("RagasEvaluator", "RagasMetric") +__all__ = ["RagasEvaluator"] diff --git a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py index c44c446e6..1091cb902 
100644 --- a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py +++ b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py @@ -1,171 +1,275 @@ -import json -from typing import Any, Callable, Dict, List, Optional, Union +import re +from typing import Any, Dict, List, Optional, Union, get_args, get_origin -from datasets import Dataset -from haystack import DeserializationError, component, default_from_dict, default_to_dict +from haystack import Document, component +from haystack.dataclasses import ChatMessage +from langchain_core.embeddings import Embeddings as LangchainEmbeddings # type: ignore +from langchain_core.language_models import BaseLanguageModel as LangchainLLM # type: ignore +from pydantic import ValidationError # type: ignore from ragas import evaluate # type: ignore -from ragas.evaluation import Result -from ragas.metrics.base import Metric - -from .metrics import ( - METRIC_DESCRIPTORS, - InputConverters, - OutputConverters, - RagasMetric, +from ragas.dataset_schema import ( + EvaluationDataset, + EvaluationResult, + SingleTurnSample, ) +from ragas.embeddings import BaseRagasEmbeddings +from ragas.llms import BaseRagasLLM +from ragas.metrics import Metric @component class RagasEvaluator: """ A component that uses the [Ragas framework](https://docs.ragas.io/) to evaluate - inputs against a specific metric. Supported metrics are defined by `RagasMetric`. + inputs against specified Ragas metrics. Usage example: ```python - from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric + from haystack_integrations.components.evaluators.ragas import RagasEvaluator + from ragas.metrics import ContextPrecision + from ragas.llms import LangchainLLMWrapper + from langchain_openai import ChatOpenAI + + llm = ChatOpenAI(model="gpt-4o-mini") + evaluator_llm = LangchainLLMWrapper(llm) evaluator = RagasEvaluator( - metric=RagasMetric.CONTEXT_PRECISION, + ragas_metrics=[ContextPrecision()], + evaluator_llm=evaluator_llm ) output = evaluator.run( - questions=["Which is the most popular global sport?"], - contexts=[ - [ - "Football is undoubtedly the world's most popular sport with" - "major events like the FIFA World Cup and sports personalities" - "like Ronaldo and Messi, drawing a followership of more than 4" - "billion people." - ] + query="Which is the most popular global sport?", + documents=[ + "Football is undoubtedly the world's most popular sport with" + " major events like the FIFA World Cup and sports personalities" + " like Ronaldo and Messi, drawing a followership of more than 4" + " billion people." ], - ground_truths=["Football is the most popular sport with around 4 billion" "followers worldwide"], + reference="Football is the most popular sport with around 4 billion" + " followers worldwide", ) - print(output["results"]) + + output['result'] ``` """ - # Wrapped for easy mocking. - _backend_callable: Callable - _backend_metric: Metric - def __init__( self, - metric: Union[str, RagasMetric], - metric_params: Optional[Dict[str, Any]] = None, + ragas_metrics: List[Metric], + evaluator_llm: Optional[Union[BaseRagasLLM, LangchainLLM]] = None, + evaluator_embedding: Optional[Union[BaseRagasEmbeddings, LangchainEmbeddings]] = None, ): """ - Construct a new Ragas evaluator. - - :param metric: - The metric to use for evaluation. - :param metric_params: - Parameters to pass to the metric's constructor. - Refer to the `RagasMetric` class for more details - on required parameters. 
+ Constructs a new Ragas evaluator. + + :param ragas_metrics: A list of evaluation metrics from the [Ragas](https://docs.ragas.io/) library. + :param evaluator_llm: A language model used by metrics that require LLMs for evaluation. + :param evaluator_embedding: An embedding model used by metrics that require embeddings for evaluation. """ - self.metric = metric if isinstance(metric, RagasMetric) else RagasMetric.from_str(metric) - self.metric_params = metric_params - self.descriptor = METRIC_DESCRIPTORS[self.metric] - - self._init_backend() - self._init_metric() - - expected_inputs = self.descriptor.input_parameters - component.set_input_types(self, **expected_inputs) - - def _init_backend(self): - self._backend_callable = RagasEvaluator._invoke_evaluate - - def _init_metric(self): - if self.descriptor.init_parameters is not None: - if self.metric_params is None: - msg = f"Ragas metric '{self.metric}' expected init parameters but got none" - raise ValueError(msg) - elif not all(k in self.descriptor.init_parameters for k in self.metric_params.keys()): - msg = ( - f"Invalid init parameters for Ragas metric '{self.metric}'. " - f"Expected: {self.descriptor.init_parameters}" - ) - raise ValueError(msg) - elif self.metric_params is not None: - msg = ( - f"Invalid init parameters for Ragas metric '{self.metric}'. " - f"None expected but {self.metric_params} given" - ) - raise ValueError(msg) - metric_params = self.metric_params or {} - self._backend_metric = self.descriptor.backend(**metric_params) + self._validate_inputs(ragas_metrics, evaluator_llm, evaluator_embedding) + self.metrics = ragas_metrics + self.llm = evaluator_llm + self.embedding = evaluator_embedding + + def _validate_inputs( + self, + metrics: List[Metric], + llm: Optional[Union[BaseRagasLLM, LangchainLLM]], + embedding: Optional[Union[BaseRagasEmbeddings, LangchainEmbeddings]], + ) -> None: + """Validate input parameters. - @staticmethod - def _invoke_evaluate(dataset: Dataset, metric: Metric) -> Result: - return evaluate(dataset, [metric]) + :param metrics: List of Ragas metrics to validate + :param llm: Language model to validate + :param embedding: Embedding model to validate - @component.output_types(results=List[List[Dict[str, Any]]]) - def run(self, **inputs) -> Dict[str, Any]: + :return: None. """ - Run the Ragas evaluator on the provided inputs. - - :param inputs: - The inputs to evaluate. These are determined by the - metric being calculated. See `RagasMetric` for more - information. - :returns: - A dictionary with a single `results` entry that contains - a nested list of metric results. Each input can have one or more - results, depending on the metric. Each result is a dictionary - containing the following keys and values: - - `name` - The name of the metric. - - `score` - The score of the metric. + if not all(isinstance(metric, Metric) for metric in metrics): + error_message = "All items in ragas_metrics must be instances of Metric class." 
+ raise TypeError(error_message) + + if llm is not None and not isinstance(llm, (BaseRagasLLM, LangchainLLM)): + error_message = f"Expected evaluator_llm to be BaseRagasLLM or LangchainLLM, got {type(llm).__name__}" + raise TypeError(error_message) + + if embedding is not None and not isinstance(embedding, (BaseRagasEmbeddings, LangchainEmbeddings)): + error_message = ( + f"Expected evaluator_embedding to be BaseRagasEmbeddings or " + f"LangchainEmbeddings, got {type(embedding).__name__}" + ) + raise TypeError(error_message) + + @component.output_types(result=EvaluationResult) + def run( + self, + query: Optional[str] = None, + response: Optional[Union[List[ChatMessage], str]] = None, + documents: Optional[List[Union[Document, str]]] = None, + reference_contexts: Optional[List[str]] = None, + multi_responses: Optional[List[str]] = None, + reference: Optional[str] = None, + rubrics: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Evaluates the provided query against the documents and returns the evaluation result. + + :param query: The input query from the user. + :param response: A list of ChatMessage responses (typically from a language model or agent). + :param documents: A list of Haystack Document or strings that were retrieved for the query. + :param reference_contexts: A list of reference contexts that should have been retrieved for the query. + :param multi_responses: List of multiple responses generated for the query. + :param reference: A string reference answer for the query. + :param rubrics: A dictionary of evaluation rubric, where keys represent the score + and the values represent the corresponding evaluation criteria. + :return: A dictionary containing the evaluation result. """ - InputConverters.validate_input_parameters(self.metric, self.descriptor.input_parameters, inputs) - converted_inputs: List[Dict[str, str]] = list(self.descriptor.input_converter(**inputs)) # type: ignore + processed_docs = self._process_documents(documents) + processed_response = self._process_response(response) - dataset = Dataset.from_list(converted_inputs) - results = self._backend_callable(dataset=dataset, metric=self._backend_metric) + try: + sample = SingleTurnSample( + user_input=query, + retrieved_contexts=processed_docs, + reference_contexts=reference_contexts, + response=processed_response, + multi_responses=multi_responses, + reference=reference, + rubrics=rubrics, + ) - OutputConverters.validate_outputs(results) - converted_results = [ - [result.to_dict()] for result in self.descriptor.output_converter(results, self.metric, self.metric_params) - ] + except (ValueError, ValidationError) as e: + raise self._handle_conversion_error(e) from None - return {"results": converted_results} + dataset = EvaluationDataset([sample]) - def to_dict(self) -> Dict[str, Any]: + try: + result = evaluate( + dataset=dataset, + metrics=self.metrics, + llm=self.llm, + embeddings=self.embedding, + ) + except (ValueError, ValidationError) as e: + raise self._handle_evaluation_error(e) from None + + return {"result": result} + + def _process_documents(self, documents: Union[List[Union[Document, str]], None]) -> Union[List[str], None]: + """Process and validate input documents. + + :param documents: List of Documents or strings to process + :return: List of document contents as strings or None """ - Serializes the component to a dictionary. 
+        if documents:
+            first_type = type(documents[0])
+            if first_type is Document:
+                if not all(isinstance(doc, Document) for doc in documents):
+                    error_message = "All elements in documents list must be of type Document."
+                    raise ValueError(error_message)
+                return [doc.content for doc in documents]  # type: ignore[union-attr]
 
-        :returns:
-            Dictionary with serialized data.
-        :raises DeserializationError:
-            If the component cannot be serialized.
+            if first_type is str:
+                if not all(isinstance(doc, str) for doc in documents):
+                    error_message = "All elements in documents list must be strings."
+                    raise ValueError(error_message)
+                return documents
+            error_message = "Unsupported type in documents list."
+            raise ValueError(error_message)
+        return documents
+
+    def _process_response(self, response: Optional[Union[List[ChatMessage], str]]) -> Union[str, None]:
+        """Process response into expected format.
+
+        :param response: Response to process
+        :return: Processed response string or None
         """
+        if isinstance(response, list):  # Check if response is a list
+            if all(isinstance(item, ChatMessage) for item in response):
+                return response[0]._content[0].text
+            return None
+        elif isinstance(response, str):
+            return response
+        return response
+
+    def _handle_conversion_error(self, error: Exception):
+        """Handle sample-conversion validation errors with improved messages.
+
+        :param error: Original error raised while building the SingleTurnSample.
+        """
+        if isinstance(error, ValidationError):
+            field_mapping = {
+                "user_input": "query",
+                "retrieved_contexts": "documents",
+            }
+            for err in error.errors():
+                field = err["loc"][0]
+                haystack_field = field_mapping.get(field, field)
+                expected_type = self.run.__annotations__.get(haystack_field)
+                type_desc = self._get_expected_type_description(expected_type)
+                actual_type = type(err["input"]).__name__
+                example = self._get_example_input(haystack_field)
+                error_message = (
+                    f"Validation error occurred while running RagasEvaluator Component:\n"
+                    f"The '{haystack_field}' field expected '{type_desc}', "
+                    f"but got '{actual_type}'.\n"
+                    f"Hint: Provide {example}"
+                )
+                raise ValueError(error_message)
+        # Re-raise anything that is not a pydantic ValidationError unchanged,
+        # so the caller's `raise self._handle_conversion_error(e)` never raises None.
+        raise error
+
+    def _handle_evaluation_error(self, error: Exception):
+        """Re-raise Ragas evaluation errors with Haystack-style field names in the message.
+
+        :param error: Original error raised by Ragas during evaluation.
+        """
+        error_message = str(error)
+        columns_match = re.search(r"additional columns \[(.*?)\]", error_message)
+        field_mapping = {
+            "user_input": "query",
+            "retrieved_contexts": "documents",
+        }
+        if columns_match:
+            columns_str = columns_match.group(1)
+            columns = [col.strip().strip("'") for col in columns_str.split(",")]
+
+            mapped_columns = [field_mapping.get(col, col) for col in columns]
+            updated_columns_str = "[" + ", ".join(f"'{col}'" for col in mapped_columns) + "]"
+
+            # Update the list of columns in the error message
+            updated_error_message = error_message.replace(
+                columns_match.group(0), f"additional columns {updated_columns_str}"
+            )
+            raise ValueError(updated_error_message)
+        # Fall back to the original message if it does not mention missing columns.
+        raise ValueError(error_message)
+
+    def _get_expected_type_description(self, expected_type) -> str:
+        """Helper method to get a description of the expected type."""
+        if get_origin(expected_type) is Union:
+            expected_types = [getattr(t, "__name__", str(t)) for t in get_args(expected_type)]
+            return f"one of {', '.join(expected_types)}"
+        elif get_origin(expected_type) is list:
+            expected_item_type = get_args(expected_type)[0]
+            item_type_name = getattr(expected_item_type, "__name__", str(expected_item_type))
+            return f"a list of {item_type_name}"
+        elif get_origin(expected_type) is dict:
+            key_type, value_type = get_args(expected_type)
+            key_type_name = getattr(key_type, "__name__", str(key_type))
+            value_type_name = 
getattr(value_type, "__name__", str(value_type)) + return f"a dictionary with keys of type {key_type_name} and values of type {value_type_name}" + else: + # Handle non-generic types or unknown types gracefully + return getattr(expected_type, "__name__", str(expected_type)) - def check_serializable(obj: Any): - try: - json.dumps(obj) - return True - except (TypeError, OverflowError): - return False - - if not check_serializable(self.metric_params): - msg = "Ragas evaluator cannot serialize the metric parameters" - raise DeserializationError(msg) - - return default_to_dict( - self, - metric=self.metric, - metric_params=self.metric_params, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "RagasEvaluator": + def _get_example_input(self, field: str) -> str: """ - Deserializes the component from a dictionary. + Helper method to get an example input based on the field. - :param data: - Dictionary to deserialize from. - :returns: - Deserialized component. + :param field: Arguement used to make SingleTurnSample. + :returns: Example usage for the field. """ - return default_from_dict(cls, data) + examples = { + "query": "A string query like 'Question?'", + "documents": "[Document(content='Example content')]", + "reference_contexts": "['Example string 1', 'Example string 2']", + "response": "ChatMessage(_content='Hi', _role='assistant')", + "multi_responses": "['Response 1', 'Response 2']", + "reference": "'A reference string'", + "rubrics": "{'score1': 'high_similarity'}", + } + return examples.get(field, "An appropriate value based on the field's type") diff --git a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py deleted file mode 100644 index 5d6ed16bc..000000000 --- a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py +++ /dev/null @@ -1,333 +0,0 @@ -import dataclasses -import inspect -from dataclasses import dataclass -from enum import Enum -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union - -from ragas.evaluation import Result -from ragas.metrics import ( # type: ignore - AnswerCorrectness, # type: ignore - AnswerRelevancy, # type: ignore - AnswerSimilarity, # type: ignore - AspectCritique, # type: ignore - ContextPrecision, # type: ignore - ContextRecall, # type: ignore - ContextUtilization, # type: ignore - Faithfulness, # type: ignore -) -from ragas.metrics.base import Metric - - -class RagasBaseEnum(Enum): - """ - Base functionality for a Ragas enum. - """ - - def __str__(self): - return self.value - - @classmethod - def from_str(cls, string: str) -> "RagasMetric": - """ - Create a metric type from a string. - - :param string: - The string to convert. - :returns: - The metric. - """ - enum_map = {e.value: e for e in RagasMetric} - metric = enum_map.get(string) - if metric is None: - msg = f"Unknown Ragas metric '{string}'. Supported metrics: {list(enum_map.keys())}" - raise ValueError(msg) - return metric - - -class RagasMetric(RagasBaseEnum): - """ - Metrics supported by Ragas. 
- """ - - #: Answer correctness.\ - #: Inputs - `questions: List[str], responses: List[str], ground_truths: List[str]`\ - #: Parameters - `weights: Tuple[float, float]` - ANSWER_CORRECTNESS = "answer_correctness" - - #: Faithfulness.\ - #: Inputs - `questions: List[str], contexts: List[List[str]], responses: List[str]` - FAITHFULNESS = "faithfulness" - - #: Answer similarity.\ - #: Inputs - `responses: List[str], ground_truths: List[str]`\ - #: Parameters - `threshold: float` - ANSWER_SIMILARITY = "answer_similarity" - - #: Context precision.\ - #: Inputs - `questions: List[str], contexts: List[List[str]], ground_truths: List[str]` - CONTEXT_PRECISION = "context_precision" - - #: Context utilization. - #: Inputs - `questions: List[str], contexts: List[List[str]], responses: List[str]`\ - CONTEXT_UTILIZATION = "context_utilization" - - #: Context recall. - #: Inputs - `questions: List[str], contexts: List[List[str]], ground_truths: List[str]`\ - CONTEXT_RECALL = "context_recall" - - #: Aspect critique. - #: Inputs - `questions: List[str], contexts: List[List[str]], responses: List[str]`\ - #: Parameters - `name: str, definition: str, strictness: int` - ASPECT_CRITIQUE = "aspect_critique" - - #: Answer relevancy.\ - #: Inputs - `questions: List[str], contexts: List[List[str]], responses: List[str]`\ - #: Parameters - `strictness: int` - ANSWER_RELEVANCY = "answer_relevancy" - - -@dataclass(frozen=True) -class MetricResult: - """ - Result of a metric evaluation. - - :param name: - The name of the metric. - :param score: - The score of the metric. - """ - - name: str - score: float - - def to_dict(self): - return dataclasses.asdict(self) - - -@dataclass(frozen=True) -class MetricDescriptor: - """ - Descriptor for a metric. - - :param metric: - The metric. - :param backend: - The associated Ragas metric class. - :param input_parameters: - Parameters accepted by the metric. This is used - to set the input types of the evaluator component. - :param input_converter: - Callable that converts input parameters to the Ragas input format. - :param output_converter: - Callable that converts the Ragas output format to our output format. - Accepts a single output parameter and returns a list of results derived from it. - :param init_parameters: - Additional parameters that are allowed to be passed to the metric class during initialization. 
- """ - - metric: RagasMetric - backend: Type[Metric] - input_parameters: Dict[str, Type] - input_converter: Callable[[Any], Iterable[Dict[str, str]]] - output_converter: Callable[[Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult]] - init_parameters: Optional[List[str]] = None - - @classmethod - def new( - cls, - metric: RagasMetric, - backend: Type[Metric], - input_converter: Callable[[Any], Iterable[Dict[str, str]]], - output_converter: Optional[ - Callable[[Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult]] - ] = None, - *, - init_parameters: Optional[List[str]] = None, - ) -> "MetricDescriptor": - input_converter_signature = inspect.signature(input_converter) - input_parameters = {} - for name, param in input_converter_signature.parameters.items(): - if name in ("cls", "self"): - continue - elif param.kind not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): - continue - input_parameters[name] = param.annotation - - return cls( - metric=metric, - backend=backend, - input_parameters=input_parameters, - input_converter=input_converter, - output_converter=output_converter if output_converter is not None else OutputConverters.default, - init_parameters=init_parameters, - ) - - -class InputConverters: - """ - Converters for input parameters. - - The signature of the converter functions serves as the ground-truth of the - expected input parameters of a given metric. They are also responsible for validating - the input parameters and converting them to the format expected by Ragas. - """ - - @staticmethod - def _validate_input_elements(**kwargs) -> None: - for k, collection in kwargs.items(): - if not isinstance(collection, list): - msg = ( - f"Ragas evaluator expected input '{k}' to be a collection of type 'list', " - f"got '{type(collection).__name__}' instead" - ) - raise ValueError(msg) - elif not all(isinstance(x, str) for x in collection) and not all(isinstance(x, list) for x in collection): - msg = f"Ragas evaluator expects inputs to be of type 'str' or 'list' in '{k}'" - raise ValueError(msg) - - same_length = len({len(x) for x in kwargs.values()}) == 1 - if not same_length: - msg = f"Mismatching counts in the following inputs: {({k: len(v) for k, v in kwargs.items()})}" - raise ValueError(msg) - - @staticmethod - def validate_input_parameters( - metric: RagasMetric, - expected: Dict[str, Any], - received: Dict[str, Any], - ) -> None: - for param, _ in expected.items(): - if param not in received: - msg = f"Ragas evaluator expected input parameter '{param}' for metric '{metric}'" - raise ValueError(msg) - - @staticmethod - def question_context_response( - questions: List[str], contexts: List[List[str]], responses: List[str] - ) -> Iterable[Dict[str, Union[str, List[str]]]]: - InputConverters._validate_input_elements(questions=questions, contexts=contexts, responses=responses) - for q, c, r in zip(questions, contexts, responses): # type: ignore - yield {"question": q, "contexts": c, "answer": r} - - @staticmethod - def question_context_ground_truth( - questions: List[str], - contexts: List[List[str]], - ground_truths: List[str], - ) -> Iterable[Dict[str, Union[str, List[str]]]]: - InputConverters._validate_input_elements(questions=questions, contexts=contexts, ground_truths=ground_truths) - for q, c, gt in zip(questions, contexts, ground_truths): # type: ignore - yield {"question": q, "contexts": c, "ground_truth": gt} - - @staticmethod - def question_context( - questions: List[str], - contexts: List[List[str]], - ) -> 
Iterable[Dict[str, Union[str, List[str]]]]: - InputConverters._validate_input_elements(questions=questions, contexts=contexts) - for q, c in zip(questions, contexts): # type: ignore - yield {"question": q, "contexts": c} - - @staticmethod - def response_ground_truth( - responses: List[str], - ground_truths: List[str], - ) -> Iterable[Dict[str, str]]: - InputConverters._validate_input_elements(responses=responses, ground_truths=ground_truths) - for r, gt in zip(responses, ground_truths): # type: ignore - yield {"answer": r, "ground_truth": gt} - - @staticmethod - def question_response_ground_truth( - questions: List[str], - responses: List[str], - ground_truths: List[str], - ) -> Iterable[Dict[str, str]]: - InputConverters._validate_input_elements(questions=questions, ground_truths=ground_truths, responses=responses) - for q, r, gt in zip(questions, responses, ground_truths): # type: ignore - yield {"question": q, "answer": r, "ground_truth": gt} - - -class OutputConverters: - """ - Converters for results returned by Ragas. - - They are responsible for converting the results to our output format. - """ - - @staticmethod - def validate_outputs(outputs: Result) -> None: - if not isinstance(outputs, Result): - msg = f"Expected response from Ragas evaluator to be a 'Result', got '{type(outputs).__name__}'" - raise ValueError(msg) - - @staticmethod - def _extract_default_results(output: Result, metric_name: str) -> List[MetricResult]: - try: - output_scores: List[Dict[str, float]] = output.scores.to_list() - return [MetricResult(name=metric_name, score=metric_dict[metric_name]) for metric_dict in output_scores] - except KeyError as e: - msg = f"Ragas evaluator did not return an expected output for metric '{e.args[0]}'" - raise ValueError(msg) from e - - @staticmethod - def default(output: Result, metric: RagasMetric, _: Optional[Dict]) -> List[MetricResult]: - metric_name = metric.value - return OutputConverters._extract_default_results(output, metric_name) - - @staticmethod - def aspect_critique(output: Result, _: RagasMetric, metric_params: Optional[Dict[str, Any]]) -> List[MetricResult]: - if metric_params is None: - msg = "Aspect critique metric requires metric parameters" - raise ValueError(msg) - metric_name = metric_params["name"] - return OutputConverters._extract_default_results(output, metric_name) - - -METRIC_DESCRIPTORS = { - RagasMetric.ANSWER_CORRECTNESS: MetricDescriptor.new( - RagasMetric.ANSWER_CORRECTNESS, - AnswerCorrectness, - InputConverters.question_response_ground_truth, # type: ignore - init_parameters=["weights"], - ), - RagasMetric.FAITHFULNESS: MetricDescriptor.new( - RagasMetric.FAITHFULNESS, - Faithfulness, - InputConverters.question_context_response, # type: ignore - ), - RagasMetric.ANSWER_SIMILARITY: MetricDescriptor.new( - RagasMetric.ANSWER_SIMILARITY, - AnswerSimilarity, - InputConverters.response_ground_truth, # type: ignore - init_parameters=["threshold"], - ), - RagasMetric.CONTEXT_PRECISION: MetricDescriptor.new( - RagasMetric.CONTEXT_PRECISION, - ContextPrecision, - InputConverters.question_context_ground_truth, # type: ignore - ), - RagasMetric.CONTEXT_UTILIZATION: MetricDescriptor.new( - RagasMetric.CONTEXT_UTILIZATION, - ContextUtilization, - InputConverters.question_context_response, # type: ignore - ), - RagasMetric.CONTEXT_RECALL: MetricDescriptor.new( - RagasMetric.CONTEXT_RECALL, - ContextRecall, - InputConverters.question_context_ground_truth, # type: ignore - ), - RagasMetric.ASPECT_CRITIQUE: MetricDescriptor.new( - RagasMetric.ASPECT_CRITIQUE, - 
AspectCritique, - InputConverters.question_context_response, # type: ignore - OutputConverters.aspect_critique, - init_parameters=["name", "definition", "strictness"], - ), - RagasMetric.ANSWER_RELEVANCY: MetricDescriptor.new( - RagasMetric.ANSWER_RELEVANCY, - AnswerRelevancy, - InputConverters.question_context_response, # type: ignore - init_parameters=["strictness"], - ), -} diff --git a/integrations/ragas/tests/test_evaluator.py b/integrations/ragas/tests/test_evaluator.py index 0f847ed0b..f546d9e1e 100644 --- a/integrations/ragas/tests/test_evaluator.py +++ b/integrations/ragas/tests/test_evaluator.py @@ -1,350 +1,152 @@ -import copy -import os -from dataclasses import dataclass - import pytest -from datasets import Dataset -from haystack import DeserializationError -from ragas.evaluation import Result -from ragas.metrics.base import Metric - -from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric - -DEFAULT_QUESTIONS = [ - "Which is the most popular global sport?", - "Who created the Python language?", -] -DEFAULT_CONTEXTS = [ - [ - "The popularity of sports can be measured in various ways, including TV viewership, social media presence, number of participants, and economic impact.", - "Football is undoubtedly the world's most popular sport with major events like the FIFA World Cup and sports personalities like Ronaldo and Messi, drawing a followership of more than 4 billion people.", - ], - [ - "Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming language. Its design philosophy emphasizes code readability, and its language constructs aim to help programmers write clear, logical code for both small and large-scale software projects." - ], -] -DEFAULT_RESPONSES = [ - "Football is the most popular sport with around 4 billion followers worldwide", - "Python language was created by Guido van Rossum.", -] -DEFAULT_GROUND_TRUTHS = [ - "Football (Soccer) is the most popular sport in the world with almost 4 billion fans around the world.", - "Guido van Rossum is the creator of the Python programming language.", -] - - -@dataclass(frozen=True) -class Unserializable: - something: str - - -# Only returns results for the passed metrics. 
-class MockBackend: - def __init__(self, metric: RagasMetric) -> None: - self.metric = metric - - def evaluate(self, _, metric: Metric, **kwargs): - output_map = { - RagasMetric.ANSWER_CORRECTNESS: Result(scores=Dataset.from_list([{"answer_correctness": 0.5}])), - RagasMetric.FAITHFULNESS: Result(scores=Dataset.from_list([{"faithfulness": 1.0}])), - RagasMetric.ANSWER_SIMILARITY: Result(scores=Dataset.from_list([{"answer_similarity": 1.0}])), - RagasMetric.CONTEXT_PRECISION: Result(scores=Dataset.from_list([{"context_precision": 0.5}])), - RagasMetric.CONTEXT_UTILIZATION: Result(scores=Dataset.from_list([{"context_utilization": 1.0}])), - RagasMetric.CONTEXT_RECALL: Result(scores=Dataset.from_list([{"context_recall": 0.9}])), - RagasMetric.ASPECT_CRITIQUE: Result(scores=Dataset.from_list([{"harmfulness": 1.0}])), - RagasMetric.ANSWER_RELEVANCY: Result(scores=Dataset.from_list([{"answer_relevancy": 0.4}])), - } - assert isinstance(metric, Metric) - return output_map[self.metric] - - -@pytest.mark.parametrize( - "metric, init_params", - [ - (RagasMetric.ANSWER_CORRECTNESS, {"weights": [0.5, 0.5]}), - (RagasMetric.FAITHFULNESS, None), - (RagasMetric.ANSWER_SIMILARITY, {"threshold": 0.5}), - (RagasMetric.CONTEXT_PRECISION, None), - (RagasMetric.CONTEXT_UTILIZATION, None), - (RagasMetric.CONTEXT_RECALL, None), - ( - RagasMetric.ASPECT_CRITIQUE, - { - "name": "harmfulness", - "definition": "Does the submission causes or has the potential to " - "cause harm to individuals, groups, or society at " - "large?", - }, - ), - (RagasMetric.ANSWER_RELEVANCY, {"strictness": 2}), - ], -) -def test_evaluator_valid_metric_init_params(metric, init_params): - eval = RagasEvaluator( - metric, - metric_params=init_params, +from unittest import mock +from unittest.mock import MagicMock, patch +from ragas.metrics import Metric, Faithfulness +from ragas.llms import BaseRagasLLM +from ragas.embeddings import BaseRagasEmbeddings +from ragas.dataset_schema import EvaluationResult +from haystack import Document +from haystack_integrations.components.evaluators.ragas import RagasEvaluator + + +# Fixtures +@pytest.fixture +def mock_run(): + """Fixture to mock the 'run' method of RagasEvaluator.""" + with mock.patch.object(RagasEvaluator, 'run') as mock_method: + yield mock_method + + +@pytest.fixture +def ragas_evaluator(): + """Fixture to create a valid RagasEvaluator instance.""" + valid_metrics = [MagicMock(spec=Metric) for _ in range(3)] + valid_llm = MagicMock(spec=BaseRagasLLM) + valid_embedding = MagicMock(spec=BaseRagasEmbeddings) + return RagasEvaluator( + ragas_metrics=valid_metrics, + evaluator_llm=valid_llm, + evaluator_embedding=valid_embedding, ) - assert eval.metric_params == init_params - - msg = f"Invalid init parameters for Ragas metric '{metric}'. 
" - with pytest.raises(ValueError, match=msg): - RagasEvaluator( - metric, - metric_params={"invalid_param": "invalid_value"}, - ) -@pytest.mark.parametrize( - "metric", - [ - RagasMetric.ANSWER_CORRECTNESS, - RagasMetric.ANSWER_SIMILARITY, - RagasMetric.ASPECT_CRITIQUE, - RagasMetric.ANSWER_RELEVANCY, - ], -) -def test_evaluator_fails_with_no_metric_init_params(metric): - msg = f"Ragas metric '{metric}' expected init parameters but got none" - with pytest.raises(ValueError, match=msg): - RagasEvaluator( - metric, - metric_params=None, - ) +# Tests +def test_successful_initialization(ragas_evaluator): + """Test RagasEvaluator initializes correctly with valid inputs.""" + assert len(ragas_evaluator.metrics) == 3 + assert isinstance(ragas_evaluator.llm, BaseRagasLLM) + assert isinstance(ragas_evaluator.embedding, BaseRagasEmbeddings) -def test_evaluator_serde(): - init_params = { - "metric": RagasMetric.ASPECT_CRITIQUE, - "metric_params": { - "name": "harmfulness", - "definition": "Does the submission causes or has the potential to " - "cause harm to individuals, groups, or society at " - "large?", - }, - } - eval = RagasEvaluator(**init_params) - serde_data = eval.to_dict() - new_eval = RagasEvaluator.from_dict(serde_data) +def test_invalid_metrics(): + """Test RagasEvaluator raises TypeError for invalid metrics.""" + invalid_metric = "not_a_metric" - assert eval.metric == new_eval.metric - assert eval.metric_params == new_eval.metric_params + with pytest.raises(TypeError, match="All items in ragas_metrics must be instances of Metric class."): + RagasEvaluator(ragas_metrics=[invalid_metric]) - with pytest.raises(DeserializationError, match=r"cannot serialize the metric parameters"): - init_params3 = copy.deepcopy(init_params) - init_params3["metric_params"]["name"] = Unserializable("") - eval = RagasEvaluator(**init_params3) - eval.to_dict() +def test_invalid_llm(): + """Test RagasEvaluator raises TypeError for invalid evaluator_llm.""" + valid_metric = MagicMock(spec=Metric) + invalid_llm = "not_a_llm" -@pytest.mark.parametrize( - "current_metric, inputs, params", - [ - ( - RagasMetric.ANSWER_CORRECTNESS, - {"questions": [], "responses": [], "ground_truths": []}, - {"weights": [0.5, 0.5]}, - ), - (RagasMetric.FAITHFULNESS, {"questions": [], "contexts": [], "responses": []}, None), - (RagasMetric.ANSWER_SIMILARITY, {"responses": [], "ground_truths": []}, {"threshold": 0.5}), - (RagasMetric.CONTEXT_PRECISION, {"questions": [], "contexts": [], "ground_truths": []}, None), - (RagasMetric.CONTEXT_UTILIZATION, {"questions": [], "contexts": [], "responses": []}, None), - (RagasMetric.CONTEXT_RECALL, {"questions": [], "contexts": [], "ground_truths": []}, None), - ( - RagasMetric.ASPECT_CRITIQUE, - {"questions": [], "contexts": [], "responses": []}, - { - "name": "harmfulness", - "definition": "Does the submission causes or has the potential to " - "cause harm to individuals, groups, or society at " - "large?", - }, - ), - (RagasMetric.ANSWER_RELEVANCY, {"questions": [], "contexts": [], "responses": []}, {"strictness": 2}), - ], -) -def test_evaluator_valid_inputs(current_metric, inputs, params): - init_params = { - "metric": current_metric, - "metric_params": params, - } - eval = RagasEvaluator(**init_params) - eval._backend_callable = lambda dataset, metric: MockBackend(current_metric).evaluate(dataset, metric) - output = eval.run(**inputs) + with pytest.raises(TypeError, match="Expected evaluator_llm to be BaseRagasLLM or LangchainLLM"): + RagasEvaluator(ragas_metrics=[valid_metric], 
evaluator_llm=invalid_llm) -@pytest.mark.parametrize( - "current_metric, inputs, error_string, params", - [ - ( - RagasMetric.FAITHFULNESS, - {"questions": [1], "contexts": [2], "responses": [3]}, - "expects inputs to be of type 'str'", - None, - ), - ( - RagasMetric.ANSWER_RELEVANCY, - {"questions": [""], "responses": [], "contexts": []}, - "Mismatching counts ", - {"strictness": 2}, - ), - (RagasMetric.ANSWER_RELEVANCY, {"responses": []}, "expected input parameter ", {"strictness": 2}), - ], -) -def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params): - with pytest.raises(ValueError, match=error_string): - init_params = { - "metric": current_metric, - "metric_params": params, - } - eval = RagasEvaluator(**init_params) - eval._backend_callable = lambda dataset, metric: MockBackend(current_metric).evaluate(dataset, metric) - output = eval.run(**inputs) +def test_invalid_embedding(): + """Test RagasEvaluator raises TypeError for invalid evaluator_embedding.""" + valid_metric = MagicMock(spec=Metric) + invalid_embedding = "not_an_embedding" + with pytest.raises( + TypeError, match="Expected evaluator_embedding to be BaseRagasEmbeddings or LangchainEmbeddings" + ): + RagasEvaluator(ragas_metrics=[valid_metric], evaluator_embedding=invalid_embedding) -# This test validates the expected outputs of the evaluator. -# Each output is parameterized as a list of tuples, where each tuple is (name, score). -@pytest.mark.parametrize( - "current_metric, inputs, expected_outputs, metric_params", - [ - ( - RagasMetric.ANSWER_CORRECTNESS, - {"questions": ["q1"], "responses": ["r1"], "ground_truths": ["gt1"]}, - [[(None, 0.5)]], - {"weights": [0.5, 0.5]}, - ), - ( - RagasMetric.FAITHFULNESS, - {"questions": ["q2"], "contexts": [["c2"]], "responses": ["r2"]}, - [[(None, 1.0)]], - None, - ), - ( - RagasMetric.ANSWER_SIMILARITY, - {"responses": ["r3"], "ground_truths": ["gt3"]}, - [[(None, 1.0)]], - {"threshold": 0.5}, - ), - ( - RagasMetric.CONTEXT_PRECISION, - {"questions": ["q4"], "contexts": [["c4"]], "ground_truths": ["gt44"]}, - [[(None, 0.5)]], - None, - ), - ( - RagasMetric.CONTEXT_UTILIZATION, - {"questions": ["q5"], "contexts": [["c5"]], "responses": ["r5"]}, - [[(None, 1.0)]], - None, - ), - ( - RagasMetric.CONTEXT_RECALL, - {"questions": ["q6"], "contexts": [["c6"]], "ground_truths": ["gt6"]}, - [[(None, 0.9)]], - None, - ), - ( - RagasMetric.ASPECT_CRITIQUE, - {"questions": ["q7"], "contexts": [["c7"]], "responses": ["r7"]}, - [[("harmfulness", 1.0)]], - { - "name": "harmfulness", - "definition": "Does the submission causes or has the potential to " - "cause harm to individuals, groups, or society at " - "large?", - }, - ), - ( - RagasMetric.ANSWER_RELEVANCY, - {"questions": ["q9"], "contexts": [["c9"]], "responses": ["r9"]}, - [[(None, 0.4)]], - {"strictness": 2}, - ), - ], -) -def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_params): - init_params = { - "metric": current_metric, - "metric_params": metric_params, - } - eval = RagasEvaluator(**init_params) - eval._backend_callable = lambda dataset, metric: MockBackend(current_metric).evaluate(dataset, metric) - results = eval.run(**inputs)["results"] - - assert type(results) == type(expected_outputs) - assert len(results) == len(expected_outputs) - for r, o in zip(results, expected_outputs): - assert len(r) == len(o) +def test_initializer_allows_optional_llm_and_embeddings(): + """Test RagasEvaluator initializes correctly with None for optional parameters.""" + valid_metric = MagicMock(spec=Metric) 
- expected = {(name if name is not None else str(current_metric), score) for name, score in o} - got = {(x["name"], x["score"]) for x in r} - assert got == expected + evaluator = RagasEvaluator( + ragas_metrics=[valid_metric], + evaluator_llm=None, + evaluator_embedding=None, + ) + assert evaluator.metrics == [valid_metric] + assert evaluator.llm is None + assert evaluator.embedding is None -# This integration test validates the evaluator by running it against the -# OpenAI API. It is parameterized by the metric, the inputs to the evaluator -# and the metric parameters. -@pytest.mark.asyncio -@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") @pytest.mark.parametrize( - "metric, inputs, metric_params", + "invalid_input,field_name,error_message", [ - ( - RagasMetric.ANSWER_CORRECTNESS, - {"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES, "ground_truths": DEFAULT_GROUND_TRUTHS}, - {"weights": [0.5, 0.5]}, - ), - ( - RagasMetric.FAITHFULNESS, - {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, - None, - ), - ( - RagasMetric.ANSWER_SIMILARITY, - {"responses": DEFAULT_QUESTIONS, "ground_truths": DEFAULT_GROUND_TRUTHS}, - {"threshold": 0.5}, - ), - ( - RagasMetric.CONTEXT_PRECISION, - {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "ground_truths": DEFAULT_GROUND_TRUTHS}, - None, - ), - ( - RagasMetric.CONTEXT_UTILIZATION, - {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, - None, - ), - ( - RagasMetric.CONTEXT_RECALL, - {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "ground_truths": DEFAULT_GROUND_TRUTHS}, - None, - ), - ( - RagasMetric.ASPECT_CRITIQUE, - {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, - { - "name": "harmfulness", - "definition": "Does the submission causes or has the potential to " - "cause harm to individuals, groups, or society at " - "large?", - }, - ), - ( - RagasMetric.ANSWER_RELEVANCY, - {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, - {"strictness": 2}, - ), + (["Invalid query type"], "query", "'query' field expected"), + ([123, ["Invalid document"]], "documents", "Unsupported type in documents list"), + (["score_1"], "rubrics", "'rubrics' field expected"), ], ) -def test_integration_run(metric, inputs, metric_params): - init_params = { - "metric": metric, - "metric_params": metric_params, - } - eval = RagasEvaluator(**init_params) - output = eval.run(**inputs) +def test_run_invalid_inputs(invalid_input, field_name, error_message): + """Test RagasEvaluator raises ValueError for invalid input types.""" + evaluator = RagasEvaluator(ragas_metrics=[Faithfulness()]) + query = "Which is the most popular global sport?" 
+ documents = ["Football is the most popular sport."] + response = "Football is the most popular sport in the world" + + with pytest.raises(ValueError) as exc_info: + if field_name == "query": + evaluator.run(query=invalid_input, documents=documents, response=response) + elif field_name == "documents": + evaluator.run(query=query, documents=invalid_input, response=response) + elif field_name == "rubrics": + evaluator.run(query=query, rubrics=invalid_input, documents=documents, response=response) + + assert error_message in str(exc_info.value) + + +def test_missing_columns_in_dataset(): + """Test if RagasEvaluator raises a ValueError when required columns are missing for a specific metric.""" + evaluator = RagasEvaluator(ragas_metrics=[Faithfulness()]) + query = "Which is the most popular global sport?" + reference = "Football is the most popular sport with around 4 billion followers worldwide" + response = "Football is the most popular sport in the world" + + with pytest.raises(ValueError) as exc_info: + evaluator.run(query=query, reference=reference, response=response) + + assert "faithfulness" in str(exc_info.value) + assert "documents" in str(exc_info.value) + + +def test_run_valid_input(mock_run): + """Test RagasEvaluator runs successfully with valid input.""" + mock_run.return_value = {"result": {"score": MagicMock(), "details": MagicMock(spec=EvaluationResult)}} + evaluator = RagasEvaluator(ragas_metrics=[MagicMock(Metric)]) + + query = "Which is the most popular global sport?" + response = "Football is the most popular sport in the world" + documents = [ + Document(content="Football is the world's most popular sport."), + Document(content="Football has over 4 billion followers."), + ] + reference_contexts = ["Football is a globally popular sport."] + multi_responses = ["Football is considered the most popular sport."] + reference = "Football is the most popular sport with around 4 billion followers worldwide" + rubrics = {"accuracy": "high", "relevance": "high"} + + output = evaluator.run( + query=query, + response=response, + documents=documents, + reference_contexts=reference_contexts, + multi_responses=multi_responses, + reference=reference, + rubrics=rubrics, + ) - assert isinstance(output, dict) - assert len(output) == 1 - assert "results" in output - assert len(output["results"]) == len(next(iter(inputs.values()))) + assert "result" in output + assert isinstance(output["result"], dict) + assert "score" in output["result"] + assert isinstance(output["result"]["details"], EvaluationResult) diff --git a/integrations/ragas/tests/test_metrics.py b/integrations/ragas/tests/test_metrics.py deleted file mode 100644 index 7447689fb..000000000 --- a/integrations/ragas/tests/test_metrics.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - -from haystack_integrations.components.evaluators.ragas import RagasMetric - - -def test_ragas_metric(): - for e in RagasMetric: - assert e == RagasMetric.from_str(e.value) - - with pytest.raises(ValueError, match="Unknown Ragas metric"): - RagasMetric.from_str("smugness") From 40cdbcef8c3d3161b99bce0564c777e9d67d5455 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 28 Jan 2025 12:30:54 +0000 Subject: [PATCH 219/229] Update the changelog --- integrations/ragas/CHANGELOG.md | 49 ++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/integrations/ragas/CHANGELOG.md b/integrations/ragas/CHANGELOG.md index 94946bddc..4ef3438fb 100644 --- a/integrations/ragas/CHANGELOG.md +++ b/integrations/ragas/CHANGELOG.md @@ 
-1,5 +1,20 @@ # Changelog +## [integrations/ragas-v2.0.0] - 2025-01-28 + +### 🚀 Features + +- Update Ragas integration (#1312) + +### ⚙️ CI + +- Adopt uv as installer (#1142) + +### 🧹 Chores + +- Update ruff linting scripts and settings (#1105) + + ## [integrations/ragas-v1.0.1] - 2024-09-11 ### 🐛 Bug Fixes @@ -10,48 +25,68 @@ - Do not retry tests in `hatch run test` command (#954) + ## [integrations/ragas-v1.0.0] - 2024-07-24 -### ⚙️ Miscellaneous Tasks +### ⚙️ CI - Retry tests to reduce flakyness (#836) + +### 🧹 Chores + - Update ruff invocation to include check parameter (#853) - Ragas - remove context relevancy metric (#917) +### 🌀 Miscellaneous + +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) + ## [integrations/ragas-v0.2.0] - 2024-04-23 -## [integrations/ragas-v0.1.3] - 2024-04-09 +### 🌀 Miscellaneous -### 🐛 Bug Fixes +- Unpin ragas (#677) -- Fix haystack-ai pin (#649) +## [integrations/ragas-v0.1.3] - 2024-04-09 +### 🐛 Bug Fixes +- Fix `haystack-ai` pins (#649) ### 📚 Documentation - Disable-class-def (#556) +### 🌀 Miscellaneous + +- Make tests show coverage (#566) +- Remove references to Python 3.7 (#601) + ## [integrations/ragas-v0.1.2] - 2024-03-08 ### 📚 Documentation - Update `ragas-haystack` docstrings (#529) +### 🌀 Miscellaneous + +- [RAGAS] fix: Metric parameter validation and metric descriptors (#555) + ## [integrations/ragas-v0.1.1] - 2024-02-23 ### 🐛 Bug Fixes - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) -### Build +### 🌀 Miscellaneous +- Add Ragas integration (#404) - Pin `ragas` dependency to `0.1.1` (#476) From b666b5a7264927fe33ff1b69cf62d11e6a480743 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 28 Jan 2025 13:32:48 +0100 Subject: [PATCH 220/229] refactor: Migrate Cohere to V2 (#1321) * Migrate Cohere to V2 --------- Co-authored-by: Stefano Fiorucci Co-authored-by: David S. 
Batista --- .../embedders/cohere/document_embedder.py | 21 +++++-- .../embedders/cohere/embedding_types.py | 37 +++++++++++ .../embedders/cohere/text_embedder.py | 26 ++++++-- .../components/embedders/cohere/utils.py | 61 ++++++++++++++++--- .../components/rankers/cohere/ranker.py | 18 +++++- .../cohere/tests/test_cohere_ranker.py | 11 +++- .../cohere/tests/test_document_embedder.py | 8 ++- .../cohere/tests/test_text_embedder.py | 5 ++ 8 files changed, 163 insertions(+), 24 deletions(-) create mode 100644 integrations/cohere/src/haystack_integrations/components/embedders/cohere/embedding_types.py diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index d311662fe..cbb68a8e1 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -7,7 +7,8 @@ from haystack import Document, component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace -from cohere import AsyncClient, Client +from cohere import AsyncClientV2, ClientV2 +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response @@ -47,6 +48,7 @@ def __init__( progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + embedding_type: Optional[EmbeddingTypes] = None, ): """ :param api_key: the Cohere API key. @@ -72,6 +74,8 @@ def __init__( to keep the logs clean. :param meta_fields_to_embed: list of meta fields that should be embedded along with the Document text. :param embedding_separator: separator used to concatenate the meta fields to the Document text. + :param embedding_type: the type of embeddings to return. Defaults to float embeddings. + Note that int8, uint8, binary, and ubinary are only valid for v3 models. 
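As an aside, a minimal sketch (not part of the patch) of the new `embedding_type` option and its serialization, based on the parameters and `to_dict`/`from_dict` behaviour shown in this diff; the enum value chosen here is only an example.

```python
# Sketch only: exercises the embedding_type parameter introduced in this patch.
# Construction does not call the API, but the component expects a Cohere API key
# to be available via the COHERE_API_KEY / CO_API_KEY environment variables.
from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder
from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes

embedder = CohereDocumentEmbedder(embedding_type=EmbeddingTypes.INT8)

# The enum is serialized to its string value ...
assert embedder.to_dict()["init_parameters"]["embedding_type"] == "int8"

# ... and from_dict converts the string back to the enum.
restored = CohereDocumentEmbedder.from_dict(embedder.to_dict())
assert restored.embedding_type is EmbeddingTypes.INT8
```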
""" self.api_key = api_key @@ -85,6 +89,7 @@ def __init__( self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator + self.embedding_type = embedding_type or EmbeddingTypes.FLOAT def to_dict(self) -> Dict[str, Any]: """ @@ -106,6 +111,7 @@ def to_dict(self) -> Dict[str, Any]: progress_bar=self.progress_bar, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + embedding_type=self.embedding_type.value, ) @classmethod @@ -120,6 +126,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereDocumentEmbedder": """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) + + # Convert embedding_type string to EmbeddingTypes enum value + init_params["embedding_type"] = EmbeddingTypes.from_str(init_params["embedding_type"]) + return default_from_dict(cls, data) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: @@ -163,17 +173,19 @@ def run(self, documents: List[Document]): assert api_key is not None if self.use_async_client: - cohere_client = AsyncClient( + cohere_client = AsyncClientV2( api_key, base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) all_embeddings, metadata = asyncio.run( - get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate) + get_async_response( + cohere_client, texts_to_embed, self.model, self.input_type, self.truncate, self.embedding_type + ) ) else: - cohere_client = Client( + cohere_client = ClientV2( api_key, base_url=self.api_base_url, timeout=self.timeout, @@ -187,6 +199,7 @@ def run(self, documents: List[Document]): self.truncate, self.batch_size, self.progress_bar, + self.embedding_type, ) for doc, embeddings in zip(documents, all_embeddings): diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/embedding_types.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/embedding_types.py new file mode 100644 index 000000000..2f11c02cb --- /dev/null +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/embedding_types.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from enum import Enum + + +class EmbeddingTypes(Enum): + """ + Supported types for Cohere embeddings. + + FLOAT: Default float embeddings. Valid for all models. + INT8: Signed int8 embeddings. Valid for only v3 models. + UINT8: Unsigned int8 embeddings. Valid for only v3 models. + BINARY: Signed binary embeddings. Valid for only v3 models. + UBINARY: Unsigned binary embeddings. Valid for only v3 models. + """ + + FLOAT = "float" + INT8 = "int8" + UINT8 = "uint8" + BINARY = "binary" + UBINARY = "ubinary" + + def __str__(self): + return self.value + + @staticmethod + def from_str(string: str) -> "EmbeddingTypes": + """ + Convert a string to an EmbeddingTypes enum. + """ + enum_map = {e.value: e for e in EmbeddingTypes} + embedding_type = enum_map.get(string.lower()) + if embedding_type is None: + msg = f"Unknown embedding type '{string}'. 
Supported types are: {list(enum_map.keys())}" + raise ValueError(msg) + return embedding_type diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py index c1e9bd613..fc7ff8cd2 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py @@ -2,12 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 import asyncio -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace -from cohere import AsyncClient, Client +from cohere import AsyncClientV2, ClientV2 +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response @@ -40,6 +41,7 @@ def __init__( truncate: str = "END", use_async_client: bool = False, timeout: int = 120, + embedding_type: Optional[EmbeddingTypes] = None, ): """ :param api_key: the Cohere API key. @@ -60,6 +62,8 @@ def __init__( :param use_async_client: flag to select the AsyncClient. It is recommended to use AsyncClient for applications with many concurrent calls. :param timeout: request timeout in seconds. + :param embedding_type: the type of embeddings to return. Defaults to float embeddings. + Note that int8, uint8, binary, and ubinary are only valid for v3 models. """ self.api_key = api_key @@ -69,6 +73,7 @@ def __init__( self.truncate = truncate self.use_async_client = use_async_client self.timeout = timeout + self.embedding_type = embedding_type or EmbeddingTypes.FLOAT def to_dict(self) -> Dict[str, Any]: """ @@ -86,6 +91,7 @@ def to_dict(self) -> Dict[str, Any]: truncate=self.truncate, use_async_client=self.use_async_client, timeout=self.timeout, + embedding_type=self.embedding_type.value, ) @classmethod @@ -100,6 +106,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder": """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) + + # Convert embedding_type string to EmbeddingTypes enum value + init_params["embedding_type"] = EmbeddingTypes.from_str(init_params["embedding_type"]) + return default_from_dict(cls, data) @component.output_types(embedding=List[float], meta=Dict[str, Any]) @@ -125,22 +135,26 @@ def run(self, text: str): assert api_key is not None if self.use_async_client: - cohere_client = AsyncClient( + cohere_client = AsyncClientV2( api_key, base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) embedding, metadata = asyncio.run( - get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate) + get_async_response( + cohere_client, [text], self.model, self.input_type, self.truncate, self.embedding_type + ) ) else: - cohere_client = Client( + cohere_client = ClientV2( api_key, base_url=self.api_base_url, timeout=self.timeout, client_name="haystack", ) - embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate) + embedding, metadata = get_response( + cohere_client, [text], self.model, self.input_type, self.truncate, embedding_type=self.embedding_type + ) return {"embedding": embedding[0], "meta": metadata} diff --git 
a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py index a5c20cb35..951938143 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py @@ -1,14 +1,22 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from tqdm import tqdm -from cohere import AsyncClient, Client +from cohere import AsyncClientV2, ClientV2 +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes -async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): +async def get_async_response( + cohere_async_client: AsyncClientV2, + texts: List[str], + model_name, + input_type, + truncate, + embedding_type: Optional[EmbeddingTypes] = None, +): """Embeds a list of texts asynchronously using the Cohere API. :param cohere_async_client: the Cohere `AsyncClient` @@ -17,6 +25,7 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], :param input_type: one of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed. :param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length. + :param embedding_type: the type of embeddings to return. Defaults to float embeddings. :returns: A tuple of the embeddings and metadata. @@ -25,17 +34,36 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} - response = await cohere_async_client.embed(texts=texts, model=model_name, input_type=input_type, truncate=truncate) + embedding_type = embedding_type or EmbeddingTypes.FLOAT + response = await cohere_async_client.embed( + texts=texts, + model=model_name, + input_type=input_type, + truncate=truncate, + embedding_types=[embedding_type.value], + ) if response.meta is not None: metadata = response.meta - for emb in response.embeddings: - all_embeddings.append(emb) + for emb_tuple in response.embeddings: + # emb_tuple[0] is a str denoting the embedding type (e.g. "float", "int8", etc.) + if emb_tuple[1] is not None: + # ok we have embeddings for this type, let's take all + # the embeddings (a list of embeddings) and break the loop + all_embeddings.extend(emb_tuple[1]) + break return all_embeddings, metadata def get_response( - cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False + cohere_client: ClientV2, + texts: List[str], + model_name, + input_type, + truncate, + batch_size=32, + progress_bar=False, + embedding_type: Optional[EmbeddingTypes] = None, ) -> Tuple[List[List[float]], Dict[str, Any]]: """Embeds a list of texts using the Cohere API. @@ -47,6 +75,7 @@ def get_response( :param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length. :param batch_size: the batch size to use :param progress_bar: if `True`, show a progress bar + :param embedding_type: the type of embeddings to return. Defaults to float embeddings. :returns: A tuple of the embeddings and metadata. 
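As an aside on the helpers above: with the V2 client, `response.embeddings` is grouped by embedding type, and the code takes the first group whose values are not `None`. A stand-in sketch of that unpacking (the mock object and numbers are invented; no API call is made):

```python
# Stand-in sketch of the (embedding_type, values) unpacking used above.
from types import SimpleNamespace

mock_response = SimpleNamespace(
    # one (embedding_type, values) pair per requested type; only "float" has data here
    embeddings=[("float", [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), ("int8", None)],
)

all_embeddings = []
for emb_type, values in mock_response.embeddings:
    if values is not None:
        all_embeddings.extend(values)
        break  # first non-None group wins

print(all_embeddings)  # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
```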
@@ -55,6 +84,7 @@ def get_response( all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} + embedding_type = embedding_type or EmbeddingTypes.FLOAT for i in tqdm( range(0, len(texts), batch_size), @@ -62,9 +92,20 @@ def get_response( desc="Calculating embeddings", ): batch = texts[i : i + batch_size] - response = cohere_client.embed(texts=batch, model=model_name, input_type=input_type, truncate=truncate) - for emb in response.embeddings: - all_embeddings.append(emb) + response = cohere_client.embed( + texts=batch, + model=model_name, + input_type=input_type, + truncate=truncate, + embedding_types=[embedding_type.value], + ) + ## response.embeddings always returns 5 tuples, one tuple per embedding type + ## let's take first non None tuple as that's the one we want + for emb_tuple in response.embeddings: + # emb_tuple[0] is a str denoting the embedding type (e.g. "float", "int8", etc.) + if emb_tuple[1] is not None: + # ok we have embeddings for this type, let's take all the embeddings (a list of embeddings) + all_embeddings.extend(emb_tuple[1]) if response.meta is not None: metadata = response.meta diff --git a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py index 7da823bbc..2c3060cb9 100644 --- a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py +++ b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py @@ -40,6 +40,7 @@ def __init__( max_chunks_per_doc: Optional[int] = None, meta_fields_to_embed: Optional[List[str]] = None, meta_data_separator: str = "\n", + max_tokens_per_doc: int = 4096, ): """ Creates an instance of the 'CohereRanker'. @@ -57,6 +58,7 @@ def __init__( with the document content for reranking. :param meta_data_separator: Separator used to concatenate the meta fields to the Document content. + :param max_tokens_per_doc: The maximum number of tokens to embed for each document defaults to 4096. 
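A short usage sketch (not part of the patch) for the new `max_tokens_per_doc` parameter, using values that mirror the tests further below; the import path is assumed from the package layout in this diff, and the component resolves a Cohere API key from the environment at construction time.

```python
# Hedged sketch: import path and env-var setup are assumptions, values are examples.
from haystack_integrations.components.rankers.cohere import CohereRanker

ranker = CohereRanker(model="rerank-multilingual-v2.0", top_k=2, max_tokens_per_doc=512)
assert ranker.to_dict()["init_parameters"]["max_tokens_per_doc"] == 512
```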
""" self.model_name = model self.api_key = api_key @@ -65,7 +67,18 @@ def __init__( self.max_chunks_per_doc = max_chunks_per_doc self.meta_fields_to_embed = meta_fields_to_embed or [] self.meta_data_separator = meta_data_separator - self._cohere_client = cohere.Client( + self.max_tokens_per_doc = max_tokens_per_doc + if max_chunks_per_doc is not None: + # Note: max_chunks_per_doc is currently not supported by the Cohere V2 API + # See: https://docs.cohere.com/reference/rerank + import warnings + + warnings.warn( + "The max_chunks_per_doc parameter currently has no effect as it is not supported by the Cohere V2 API.", + UserWarning, + stacklevel=2, + ) + self._cohere_client = cohere.ClientV2( api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" ) @@ -85,6 +98,7 @@ def to_dict(self) -> Dict[str, Any]: max_chunks_per_doc=self.max_chunks_per_doc, meta_fields_to_embed=self.meta_fields_to_embed, meta_data_separator=self.meta_data_separator, + max_tokens_per_doc=self.max_tokens_per_doc, ) @classmethod @@ -152,7 +166,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None model=self.model_name, query=query, documents=cohere_input_docs, - max_chunks_per_doc=self.max_chunks_per_doc, + max_tokens_per_doc=self.max_tokens_per_doc, top_n=top_k, ) indices = [output.index for output in response.results] diff --git a/integrations/cohere/tests/test_cohere_ranker.py b/integrations/cohere/tests/test_cohere_ranker.py index ff861b39d..34a9d1456 100644 --- a/integrations/cohere/tests/test_cohere_ranker.py +++ b/integrations/cohere/tests/test_cohere_ranker.py @@ -20,7 +20,7 @@ def mock_ranker_response(): RerankResult, RerankResult] """ - with patch("cohere.Client.rerank", autospec=True) as mock_ranker_response: + with patch("cohere.ClientV2.rerank", autospec=True) as mock_ranker_response: mock_response = Mock() @@ -48,6 +48,7 @@ def test_init_default(self, monkeypatch): assert component.max_chunks_per_doc is None assert component.meta_fields_to_embed == [] assert component.meta_data_separator == "\n" + assert component.max_tokens_per_doc == 4096 def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("CO_API_KEY", raising=False) @@ -65,6 +66,7 @@ def test_init_with_parameters(self, monkeypatch): max_chunks_per_doc=40, meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=",", + max_tokens_per_doc=100, ) assert component.model_name == "rerank-multilingual-v2.0" assert component.top_k == 5 @@ -73,6 +75,7 @@ def test_init_with_parameters(self, monkeypatch): assert component.max_chunks_per_doc == 40 assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] assert component.meta_data_separator == "," + assert component.max_tokens_per_doc == 100 def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("CO_API_KEY", "test-api-key") @@ -88,6 +91,7 @@ def test_to_dict_default(self, monkeypatch): "max_chunks_per_doc": None, "meta_fields_to_embed": [], "meta_data_separator": "\n", + "max_tokens_per_doc": 4096, }, } @@ -101,6 +105,7 @@ def test_to_dict_with_parameters(self, monkeypatch): max_chunks_per_doc=50, meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=",", + max_tokens_per_doc=100, ) data = component.to_dict() assert data == { @@ -113,6 +118,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "max_chunks_per_doc": 50, "meta_fields_to_embed": ["meta_field_1", "meta_field_2"], "meta_data_separator": ",", + "max_tokens_per_doc": 100, }, } @@ -128,6 +134,7 @@ def 
test_from_dict(self, monkeypatch): "max_chunks_per_doc": 50, "meta_fields_to_embed": ["meta_field_1", "meta_field_2"], "meta_data_separator": ",", + "max_tokens_per_doc": 100, }, } component = CohereRanker.from_dict(data) @@ -138,6 +145,7 @@ def test_from_dict(self, monkeypatch): assert component.max_chunks_per_doc == 50 assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] assert component.meta_data_separator == "," + assert component.max_tokens_per_doc == 100 def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("CO_API_KEY", raising=False) @@ -149,6 +157,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "top_k": 2, "max_chunks_per_doc": 50, + "max_tokens_per_doc": 100, }, } with pytest.raises(ValueError, match="None of the following authentication environment variables are set: *"): diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index d69e1a5a2..895e27c7d 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -8,6 +8,7 @@ from haystack.utils import Secret from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder +from haystack_integrations.components.embedders.cohere.embedding_types import EmbeddingTypes pytestmark = pytest.mark.embedders COHERE_API_URL = "https://api.cohere.com" @@ -27,6 +28,7 @@ def test_init_default(self): assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" + assert embedder.embedding_type == EmbeddingTypes.FLOAT def test_init_with_parameters(self): embedder = CohereDocumentEmbedder( @@ -53,6 +55,7 @@ def test_init_with_parameters(self): assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == "-" + assert embedder.embedding_type == EmbeddingTypes.FLOAT def test_to_dict(self): embedder_component = CohereDocumentEmbedder() @@ -71,6 +74,7 @@ def test_to_dict(self): "progress_bar": True, "meta_fields_to_embed": [], "embedding_separator": "\n", + "embedding_type": "float", }, } @@ -87,6 +91,7 @@ def test_to_dict_with_custom_init_parameters(self): progress_bar=False, meta_fields_to_embed=["text_field"], embedding_separator="-", + embedding_type=EmbeddingTypes.INT8, ) component_dict = embedder_component.to_dict() assert component_dict == { @@ -103,6 +108,7 @@ def test_to_dict_with_custom_init_parameters(self): "progress_bar": False, "meta_fields_to_embed": ["text_field"], "embedding_separator": "-", + "embedding_type": "int8", }, } @@ -112,7 +118,7 @@ def test_to_dict_with_custom_init_parameters(self): ) @pytest.mark.integration def test_run(self): - embedder = CohereDocumentEmbedder() + embedder = CohereDocumentEmbedder(model="embed-english-v2.0", embedding_type=EmbeddingTypes.FLOAT) docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index 80f7c1a3e..58fff3900 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -7,6 +7,7 @@ from haystack.utils import Secret from haystack_integrations.components.embedders.cohere import CohereTextEmbedder +from haystack_integrations.components.embedders.cohere.embedding_types import 
EmbeddingTypes pytestmark = pytest.mark.embedders COHERE_API_URL = "https://api.cohere.com" @@ -47,6 +48,7 @@ def test_init_with_parameters(self): assert embedder.truncate == "START" assert embedder.use_async_client is True assert embedder.timeout == 60 + assert embedder.embedding_type == EmbeddingTypes.FLOAT def test_to_dict(self): """ @@ -64,6 +66,7 @@ def test_to_dict(self): "truncate": "END", "use_async_client": False, "timeout": 120, + "embedding_type": "float", }, } @@ -79,6 +82,7 @@ def test_to_dict_with_custom_init_parameters(self): truncate="START", use_async_client=True, timeout=60, + embedding_type=EmbeddingTypes.INT8, ) component_dict = embedder_component.to_dict() assert component_dict == { @@ -91,6 +95,7 @@ def test_to_dict_with_custom_init_parameters(self): "truncate": "START", "use_async_client": True, "timeout": 60, + "embedding_type": "int8", }, } From ea63bab83e1181a725855703ebc62ae22894d450 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 28 Jan 2025 13:33:14 +0100 Subject: [PATCH 221/229] feat: CohereChatGenerator - add tools support (#1318) * CohereChatGenerator v2 upgrade + tools --- integrations/cohere/pyproject.toml | 1 + .../generators/cohere/chat/chat_generator.py | 406 +++++++++++++----- .../tests/test_cohere_chat_generator.py | 319 +++++++++++--- .../cohere/tests/test_cohere_generator.py | 3 + 4 files changed, 560 insertions(+), 169 deletions(-) diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index 262b1612d..c7a8fcdc2 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", + "jsonschema" # for tools ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 33e7c98f6..d169ddf6d 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -1,37 +1,290 @@ +import json import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, Generator, List, Optional from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall from haystack.lazy_imports import LazyImport +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from cohere import ChatResponse + with LazyImport(message="Run 'pip install cohere'") as cohere_import: import cohere + logger = logging.getLogger(__name__) +def _format_tool(tool: Tool) -> Dict[str, Any]: + """ + Formats a Haystack Tool into Cohere's function specification format. + + The function transforms the tool's properties (name, description, parameters) + into the structure expected by Cohere's API. + + :param tool: The Haystack Tool to format. + :return: Dictionary formatted according to Cohere's function specification. 
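For illustration, a sketch (not from the patch) of the function specification `_format_tool` is meant to produce; the tool and its schema are made-up examples, and the assert assumes the snippet runs in the same module as the helper.

```python
# Illustrative only: expected Cohere function-spec payload for a simple Tool.
from haystack.tools import Tool

def weather(city: str) -> str:
    return f"The weather in {city} is sunny and 32°C"

weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string", "description": "City name, e.g. Paris"}},
        "required": ["city"],
    },
    function=weather,
)

assert _format_tool(weather_tool) == {
    "type": "function",
    "function": {
        "name": "weather",
        "description": "useful to determine the weather in a given location",
        "parameters": weather_tool.parameters,
    },
}
```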
+ """ + return { + "type": "function", + "function": {"name": tool.name, "description": tool.description, "parameters": tool.parameters}, + } + + +def _format_message(message: ChatMessage) -> Dict[str, Any]: + """ + Formats a Haystack ChatMessage into Cohere's chat format. + + The function handles message components including: + - Text content + - Tool calls + - Tool call results + + :param message: Haystack ChatMessage to format. + :return: Dictionary formatted according to Cohere's chat specification. + """ + if not message.texts and not message.tool_calls and not message.tool_call_results: + msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + raise ValueError(msg) + + cohere_msg: Dict[str, Any] = {"role": message.role.value} + + # Format the message based on its content type + if message.tool_call_results: + result = message.tool_call_results[0] # We expect one result at a time + if result.origin.id is None: + msg = "`ToolCall` must have a non-null `id` attribute to be used with Cohere." + raise ValueError(msg) + cohere_msg.update( + { + "role": "tool", + "tool_call_id": result.origin.id, + "content": [{"type": "document", "document": {"data": json.dumps({"result": result.result})}}], + } + ) + elif message.tool_calls: + tool_calls = [] + for tool_call in message.tool_calls: + if tool_call.id is None: + msg = "`ToolCall` must have a non-null `id` attribute to be used with Cohere." + raise ValueError(msg) + tool_calls.append( + { + "id": tool_call.id, + "type": "function", + "function": {"name": tool_call.tool_name, "arguments": json.dumps(tool_call.arguments)}, + } + ) + cohere_msg.update( + { + "tool_calls": tool_calls, + "tool_plan": message.text if message.text else "", + } + ) + else: + cohere_msg["content"] = ( + [{"type": "text", "text": message.texts[0]}] if message.texts and message.texts[0] else [] + ) + if not cohere_msg["content"]: + msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + raise ValueError(msg) + + return cohere_msg + + +def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: + """ + Parses Cohere's chat response into a Haystack ChatMessage. + + Extracts and organizes various response components including: + - Text content + - Tool calls + - Usage statistics + - Citations + - Metadata + + :param chat_response: Response from Cohere's chat API. + :param model: The name of the model that generated the response. + :return: A Haystack ChatMessage containing the formatted response. 
+ """ + if chat_response.message.tool_calls: + # Convert Cohere tool calls to Haystack ToolCall objects + tool_calls = [ + ToolCall(id=tc.id, tool_name=tc.function.name, arguments=json.loads(tc.function.arguments)) + for tc in chat_response.message.tool_calls + ] + # Create message with tool plan as text and tool calls in the format Haystack expects + tool_plan = chat_response.message.tool_plan or "" + message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls) + elif chat_response.message.content: + message = ChatMessage.from_assistant(chat_response.message.content[0].text) + else: + # Handle the case where neither tool_calls nor content exists + logger.warning(f"Received empty response from Cohere API: {chat_response.message}") + message = ChatMessage.from_assistant("") + + # In V2, token usage is part of the response object, not the message + message.meta.update( + { + "model": model, + "usage": { + "prompt_tokens": (chat_response.usage.billed_units.input_tokens), + "completion_tokens": (chat_response.usage.billed_units.output_tokens), + }, + "index": 0, + "finish_reason": chat_response.finish_reason, + "citations": chat_response.message.citations, + } + ) + return message + + +def _parse_streaming_response( + response: Generator, model: str, streaming_callback: Callable[[StreamingChunk], None] +) -> ChatMessage: + """ + Parses Cohere's streaming chat response into a Haystack ChatMessage. + + Processes streaming chunks and aggregates them into a complete response, + including: + - Text content + - Tool plan + - Tool calls and their arguments + - Usage statistics + - Finish reason + + :param response: Streaming response from Cohere's chat API. + :param model: The name of the model that generated the response. + :param streaming_callback: Callback function for streaming chunks. + :return: A Haystack ChatMessage containing the formatted response. 
+ """ + response_text = "" + tool_plan = "" + tool_calls = [] + current_tool_call = None + current_tool_arguments = "" + captured_meta = {} + + for chunk in response: + if chunk and chunk.type == "content-delta": + stream_chunk = StreamingChunk(content=chunk.delta.message.content.text) + streaming_callback(stream_chunk) + response_text += chunk.delta.message.content.text + elif chunk and chunk.type == "tool-plan-delta": + tool_plan += chunk.delta.message.tool_plan + stream_chunk = StreamingChunk(content=chunk.delta.message.tool_plan) + streaming_callback(stream_chunk) + elif chunk and chunk.type == "tool-call-start": + tool_call = chunk.delta.message.tool_calls + current_tool_call = ToolCall(id=tool_call.id, tool_name=tool_call.function.name, arguments="") + elif chunk and chunk.type == "tool-call-delta": + current_tool_arguments += chunk.delta.message.tool_calls.function.arguments + elif chunk and chunk.type == "tool-call-end": + if current_tool_call: + current_tool_call.arguments = json.loads(current_tool_arguments) + tool_calls.append(current_tool_call) + current_tool_call = None + current_tool_arguments = "" + elif chunk and chunk.type == "message-end": + captured_meta.update( + { + "model": model, + "index": 0, + "finish_reason": chunk.delta.finish_reason, + "usage": { + "prompt_tokens": chunk.delta.usage.billed_units.input_tokens, + "completion_tokens": chunk.delta.usage.billed_units.output_tokens, + }, + } + ) + + # Create the appropriate ChatMessage based on what we received + if tool_calls: + chat_message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls) + else: + chat_message = ChatMessage.from_assistant(text=response_text) + + # Add metadata + chat_message.meta.update(captured_meta) + + return chat_message + + @component class CohereChatGenerator: """ - Completes chats using Cohere's models through Cohere `chat` endpoint. + Completes chats using Cohere's models using Cohere cohere.ClientV2 `chat` endpoint. - You can customize how the text is generated by passing parameters to the + You can customize how the chat response is generated by passing parameters to the Cohere API through the `**generation_kwargs` parameter. You can do this when initializing or running the component. Any parameter that works with - `cohere.Client.chat` will work here too. + `cohere.ClientV2.chat` will work here too. For details, see [Cohere API](https://docs.cohere.com/reference/chat). - ### Usage example + Below is an example of how to use the component: + + ### Simple example + ```python + from haystack.dataclasses import ChatMessage + from haystack.utils import Secret + from haystack_integrations.components.generators.cohere import CohereChatGenerator + + client = CohereChatGenerator(model="command-r", api_key=Secret.from_env_var("COHERE_API_KEY")) + messages = [ChatMessage.from_user("What's Natural Language Processing?")] + client.run(messages) + + # Output: {'replies': [ChatMessage(_role=, + # _content=[TextContent(text='Natural Language Processing (NLP) is an interdisciplinary... + ``` + + ### Advanced example + + CohereChatGenerator can be integrated into pipelines and supports Haystack's tooling + architecture, enabling tools to be invoked seamlessly across various generators. 
```python + from haystack import Pipeline + from haystack.dataclasses import ChatMessage + from haystack.components.tools import ToolInvoker + from haystack.tools import Tool from haystack_integrations.components.generators.cohere import CohereChatGenerator - component = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) - response = component.run(chat_messages) + # Create a weather tool + def weather(city: str) -> str: + return f"The weather in {city} is sunny and 32°C" + + weather_tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get weather for, e.g. Paris, London", + } + }, + "required": ["city"], + }, + function=weather, + ) + + # Create and set up the pipeline + pipeline = Pipeline() + pipeline.add_component("generator", CohereChatGenerator(model="command-r", tools=[weather_tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool])) + pipeline.connect("generator", "tool_invoker") + + # Run the pipeline with a weather query + results = pipeline.run( + data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + ) - assert response["replies"] + # The tool result will be available in the pipeline output + print(results["tool_invoker"]["tool_messages"][0].tool_call_result.result) + # Output: "The weather in Paris is sunny and 32°C" ``` """ @@ -42,6 +295,7 @@ def __init__( streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, api_base_url: Optional[str] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, **kwargs, ): """ @@ -56,29 +310,18 @@ def __init__( :param generation_kwargs: Other parameters to use for the model during generation. For a list of parameters, see [Cohere Chat endpoint](https://docs.cohere.com/reference/chat). Some of the parameters are: - - 'chat_history': A list of previous messages between the user and the model, meant to give the model - conversational context for responding to the user's message. - - 'preamble': When specified, replaces the default Cohere preamble with the provided one. - - 'conversation_id': An alternative to `chat_history`. Previous conversations can be resumed by providing - the conversation's identifier. The contents of message and the model's response are stored - as part of this conversation. If a conversation with this ID doesn't exist, - a new conversation is created. - - 'prompt_truncation': Defaults to `AUTO` when connectors are specified and to `OFF` in all other cases. - Dictates how the prompt is constructed. - - 'connectors': Accepts {"id": "web-search"}, and the "id" for a custom connector, if you created one. - When specified, the model's reply is enriched with information found by - quering each of the connectors (RAG). - - 'documents': A list of relevant documents that the model can use to enrich its reply. - - 'search_queries_only': Defaults to `False`. When `True`, the response only contains a - list of generated search queries, but no search takes place, and no reply from the model to the - user's message is generated. + - 'messages': A list of messages between the user and the model, meant to give the model + conversational context for responding to the user's message. + - 'system_message': When specified, adds a system message at the beginning of the conversation. - 'citation_quality': Defaults to `accurate`. 
Dictates the approach taken to generating citations - as part of the RAG flow by allowing the user to specify whether they want - `accurate` results or `fast` results. + as part of the RAG flow by allowing the user to specify whether they want + `accurate` results or `fast` results. - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures - mean less random generations. + mean less random generations. + :param tools: A list of Tool objects that the model can use. Each tool should have a unique name. """ cohere_import.check() + _check_duplicate_tool_names(tools) if not api_base_url: api_base_url = "https://api.cohere.com" @@ -89,8 +332,9 @@ def __init__( self.streaming_callback = streaming_callback self.api_base_url = api_base_url self.generation_kwargs = generation_kwargs + self.tools = tools self.model_parameters = kwargs - self.client = cohere.Client( + self.client = cohere.ClientV2( api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" ) @@ -108,6 +352,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, model=self.model, @@ -115,6 +360,7 @@ def to_dict(self) -> Dict[str, Any]: api_base_url=self.api_base_url, api_key=self.api_key.to_dict(), generation_kwargs=self.generation_kwargs, + tools=serialized_tools, ) @classmethod @@ -129,108 +375,56 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) + deserialize_tools_inplace(init_params, key="tools") serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) - def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: - role = "User" if message.role == ChatRole.USER else "Chatbot" - chat_message = {"user_name": role, "text": message.text} - return chat_message - @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): """ - Invoke the text generation inference based on the provided messages and generation parameters. + Invoke the chat endpoint based on the provided messages and generation parameters. :param messages: list of `ChatMessage` instances representing the input messages. - :param generation_kwargs: additional keyword arguments for text generation. These parameters will + :param generation_kwargs: additional keyword arguments for chat generation. These parameters will potentially override the parameters passed in the __init__ method. For more details on the parameters supported by the Cohere API, refer to the Cohere [documentation](https://docs.cohere.com/reference/chat). + :param tools: A list of tools for which the model can prepare calls. If set, it will override + the `tools` parameter set during component initialization. :returns: A dictionary with the following keys: - `replies`: a list of `ChatMessage` instances representing the generated responses. 
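As a usage aside for `run` (not part of the patch): a minimal streaming sketch mirroring the class docstring example above; the model name is illustrative and a Cohere API key is expected in the environment.

```python
# Hedged sketch: streams tokens through print_streaming_chunk while run() executes.
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.cohere import CohereChatGenerator

generator = CohereChatGenerator(model="command-r", streaming_callback=print_streaming_chunk)
result = generator.run([ChatMessage.from_user("What's Natural Language Processing?")])
print(result["replies"][0].text)
```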
""" # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - chat_history = [self._message_to_dict(m) for m in messages[:-1]] + + # Handle tools + tools = tools or self.tools + if tools: + _check_duplicate_tool_names(tools) + generation_kwargs["tools"] = [_format_tool(tool) for tool in tools] + + formatted_messages = [_format_message(message) for message in messages] + if self.streaming_callback: response = self.client.chat_stream( - message=messages[-1].text, model=self.model, - chat_history=chat_history, + messages=formatted_messages, **generation_kwargs, ) - - response_text = "" - finish_response = None - for event in response: - if event.event_type == "text-generation": - stream_chunk = self._build_chunk(event) - self.streaming_callback(stream_chunk) - response_text += event.text - elif event.event_type == "stream-end": - finish_response = event.response - chat_message = ChatMessage.from_assistant(response_text) - - if finish_response and finish_response.meta: - if finish_response.meta.billed_units: - tokens_in = finish_response.meta.billed_units.input_tokens or -1 - tokens_out = finish_response.meta.billed_units.output_tokens or -1 - chat_message.meta["usage"] = {"prompt_tokens": tokens_in, "completion_tokens": tokens_out} - chat_message.meta.update( - { - "model": self.model, - "index": 0, - "finish_reason": finish_response.finish_reason, - "documents": finish_response.documents, - "citations": finish_response.citations, - } - ) + chat_message = _parse_streaming_response(response, self.model, self.streaming_callback) else: response = self.client.chat( - message=messages[-1].text, model=self.model, - chat_history=chat_history, + messages=formatted_messages, **generation_kwargs, ) - chat_message = self._build_message(response) - return {"replies": [chat_message]} + chat_message = _parse_response(response, self.model) - def _build_chunk(self, chunk) -> StreamingChunk: - """ - Converts the response from the Cohere API to a StreamingChunk. - :param chunk: The chunk returned by the OpenAI API. - :param choice: The choice returned by the OpenAI API. - :returns: The StreamingChunk. - """ - chat_message = StreamingChunk(content=chunk.text, meta={"event_type": chunk.event_type}) - return chat_message - - def _build_message(self, cohere_response): - """ - Converts the non-streaming response from the Cohere API to a ChatMessage. - :param cohere_response: The completion returned by the Cohere API. - :returns: The ChatMessage. 
- """ - message = None - if cohere_response.tool_calls: - # TODO revisit to see if we need to handle multiple tool calls - message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json()) - elif cohere_response.text: - message = ChatMessage.from_assistant(cohere_response.text) - message.meta.update( - { - "model": self.model, - "usage": { - "prompt_tokens": cohere_response.meta.billed_units.input_tokens, - "completion_tokens": cohere_response.meta.billed_units.output_tokens, - }, - "index": 0, - "finish_reason": cohere_response.finish_reason, - "documents": cohere_response.documents, - "citations": cohere_response.citations, - } - ) - return message + return {"replies": [chat_message]} diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 05a18f074..476a4ae32 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -1,11 +1,13 @@ -import json import os from unittest.mock import Mock import pytest from cohere.core import ApiError +from haystack import Pipeline from haystack.components.generators.utils import print_streaming_chunk +from haystack.components.tools import ToolInvoker from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.tools import Tool from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereChatGenerator @@ -13,6 +15,14 @@ pytestmark = pytest.mark.chat_generators +def weather(city: str) -> str: + return f"The weather in {city} is sunny and 32°C" + + +def stock_price(ticker: str): + return f"The current price of {ticker} is $100" + + def streaming_chunk(text: str): """ Mock chunks of streaming responses from the Cohere API @@ -73,6 +83,7 @@ def test_to_dict_default(self, monkeypatch): "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "api_base_url": "https://api.cohere.com", "generation_kwargs": {}, + "tools": None, }, } @@ -95,6 +106,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "api_base_url": "test-base-url", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": None, }, } @@ -133,11 +145,6 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): with pytest.raises(ValueError): CohereChatGenerator.from_dict(data) - def test_message_to_dict(self, chat_messages): - obj = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) - dictionary = [obj._message_to_dict(message) for message in chat_messages] - assert dictionary == [{"user_name": "Chatbot", "text": "What's the capital of France"}] - @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", @@ -186,12 +193,9 @@ def __call__(self, chunk: StreamingChunk) -> None: assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.text - assert message.meta["finish_reason"] == "COMPLETE" - assert callback.counter > 1 assert "Paris" in callback.responses - assert "usage" in message.meta assert "prompt_tokens" in message.meta["usage"] assert "completion_tokens" in message.meta["usage"] @@ -201,80 +205,269 @@ def __call__(self, chunk: StreamingChunk) -> None: reason="Export an env var called 
COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration - def test_live_run_with_connector(self): - chat_messages = [ChatMessage.from_user("What's the capital of France")] - component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) - results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) - assert len(results["replies"]) == 1 - message: ChatMessage = results["replies"][0] - assert "Paris" in message.text - assert message.meta["documents"] is not None - assert "citations" in message.meta # Citations might be None + def test_tools_use_old_way(self): + # See https://docs.cohere.com/docs/structured-outputs-json for more information + tools_schema = [ + { + "type": "function", + "function": { + "name": "get_stock_price", + "description": "Retrieves the current stock price for a given ticker symbol.", + "parameters": { + "type": "object", + "properties": { + "ticker": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL for Apple Inc.", + } + }, + "required": ["ticker"], + }, + }, + } + ] + client = CohereChatGenerator(model="command-r") + response = client.run( + messages=[ChatMessage.from_user("What is the current price of AAPL?")], + generation_kwargs={"tools": tools_schema}, + ) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.text, "First reply text should be a tool plan" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + + assert first_reply.tool_calls, "First reply has no tool calls" + assert len(first_reply.tool_calls) == 1, "First reply has more than one tool call" + assert first_reply.tool_calls[0].tool_name == "get_stock_price", "First tool call is not get_stock_price" + assert first_reply.tool_calls[0].arguments == {"ticker": "AAPL"}, "First tool call arguments are not correct" @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration - def test_live_run_streaming_with_connector(self): - class Callback: - def __init__(self): - self.responses = "" - self.counter = 0 + def test_tools_use_with_tools(self): + stock_price_tool = Tool( + name="get_stock_price", + description="Retrieves the current stock price for a given ticker symbol.", + parameters={ + "type": "object", + "properties": { + "ticker": { + "type": "string", + "description": "The stock ticker symbol, e.g. 
AAPL for Apple Inc.", + } + }, + "required": ["ticker"], + }, + function=stock_price, + ) + initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")] + client = CohereChatGenerator(model="command-r") + response = client.run( + messages=initial_messages, + tools=[stock_price_tool], + ) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" - def __call__(self, chunk: StreamingChunk) -> None: - self.counter += 1 - self.responses += chunk.content if chunk.content else "" + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.text, "First reply text should be a tool plan" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - callback = Callback() - chat_messages = [ChatMessage.from_user("What's the capital of France? answer in a word")] - component = CohereChatGenerator(streaming_callback=callback) - results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) + assert first_reply.tool_calls, "First reply has no tool calls" + assert len(first_reply.tool_calls) == 1, "First reply has more than one tool call" + assert first_reply.tool_calls[0].tool_name == "get_stock_price", "First tool call is not get_stock_price" + assert first_reply.tool_calls[0].arguments == {"ticker": "AAPL"}, "First tool call arguments are not correct" + + # Test with tool result + new_messages = [ + initial_messages[0], + first_reply, + ChatMessage.from_tool(tool_result="150.23", origin=first_reply.tool_calls[0]), + ] + results = client.run(new_messages) assert len(results["replies"]) == 1 - message: ChatMessage = results["replies"][0] - assert "Paris" in message.text + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "150.23" in final_message.text - assert message.meta["finish_reason"] == "COMPLETE" + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_tools_streaming(self): + """ + Test that the CohereChatGenerator can run with tools and streaming callback. + """ + weather_tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get weather for, e.g. 
Paris, London", + } + }, + "required": ["city"], + }, + function=weather, + ) - assert "Paris" in callback.responses + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = CohereChatGenerator( + model="command-r", # Cohere's model that supports tools + tools=[weather_tool], + streaming_callback=print_streaming_chunk, + ) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + first_reply = results["replies"][0] - assert message.meta["documents"] is not None - assert message.meta["citations"] is not None + assert isinstance(first_reply, ChatMessage), "Reply is not a ChatMessage instance" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "Reply is not from the assistant" + assert first_reply.tool_calls, "No tool calls in the reply" + + tool_call = first_reply.tool_calls[0] + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + + # Test with tool result + new_messages = [ + initial_messages[0], + first_reply, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration - def test_tools_use(self): - # See https://docs.anthropic.com/en/docs/tool-use for more information - tools_schema = { - "name": "get_stock_price", - "description": "Retrieves the current stock price for a given ticker symbol.", - "parameter_definitions": { - "ticker": { - "type": "string", - "description": "The stock ticker symbol, e.g. AAPL for Apple Inc.", - "required": True, + def test_pipeline_with_cohere_chat_generator(self): + """ + Test that the CohereChatGenerator component can be used in a pipeline + """ + weather_tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city to get weather for, e.g. 
Paris, London", + } + }, + "required": ["city"], + }, + function=weather, + ) + + pipeline = Pipeline() + pipeline.add_component("generator", CohereChatGenerator(model="command-r", tools=[weather_tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool])) + + pipeline.connect("generator", "tool_invoker") + + results = pipeline.run( + data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + ) + + assert ( + "The weather in Paris is sunny and 32°C" + == results["tool_invoker"]["tool_messages"][0].tool_call_result.result + ) + + def test_serde_in_pipeline(self, monkeypatch): + """ + Test serialization/deserialization of CohereChatGenerator in a Pipeline, + including detailed dictionary validation + """ + # Set mock Cohere API key + monkeypatch.setenv("COHERE_API_KEY", "test-api-key") + + # Create a test tool + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={"city": {"type": "string"}}, + function=weather, + ) + + # Create generator with specific configuration + generator = CohereChatGenerator( + model="command-r", + generation_kwargs={"temperature": 0.7}, + streaming_callback=print_streaming_chunk, + tools=[tool], + ) + + # Create and configure pipeline + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + # Get pipeline dictionary and verify its structure + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", # noqa: E501 + "init_parameters": { + "model": "command-r", + "api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True}, + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "api_base_url": "https://api.cohere.com", + "generation_kwargs": {"temperature": 0.7}, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "weather", + "description": "useful to determine the weather in a given location", + "parameters": {"city": {"type": "string"}}, + "function": "tests.test_cohere_chat_generator.weather", + }, + } + ], + }, } }, + "connections": [], } - client = CohereChatGenerator(model="command-r") - response = client.run( - messages=[ChatMessage.from_user("What is the current price of AAPL?")], - generation_kwargs={"tools": [tools_schema]}, - ) - replies = response["replies"] - assert isinstance(replies, list), "Replies is not a list" - assert len(replies) > 0, "No replies received" - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.text, "First reply has no text" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "get_stock_price" in first_reply.text.lower(), "First reply does not contain get_stock_price" - assert first_reply.meta, "First reply has no metadata" - fc_response = json.loads(first_reply.text) - assert "name" in fc_response, "First reply does not contain name of the tool" - assert "parameters" in fc_response, "First reply does not contain parameters of the tool" + # Test YAML serialization/deserialization + pipeline_yaml = pipeline.dumps() + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + # Verify the loaded pipeline's generator has the same configuration + 
loaded_generator = new_pipeline.get_component("generator") + assert loaded_generator.model == generator.model + assert loaded_generator.generation_kwargs == generator.generation_kwargs + assert loaded_generator.streaming_callback == generator.streaming_callback + assert len(loaded_generator.tools) == len(generator.tools) + assert loaded_generator.tools[0].name == generator.tools[0].name + assert loaded_generator.tools[0].description == generator.tools[0].description + assert loaded_generator.tools[0].parameters == generator.tools[0].parameters diff --git a/integrations/cohere/tests/test_cohere_generator.py b/integrations/cohere/tests/test_cohere_generator.py index fffe872f5..37efb7e2c 100644 --- a/integrations/cohere/tests/test_cohere_generator.py +++ b/integrations/cohere/tests/test_cohere_generator.py @@ -52,6 +52,7 @@ def test_to_dict_default(self, monkeypatch): "streaming_callback": None, "api_base_url": COHERE_API_URL, "generation_kwargs": {}, + "tools": None, }, } @@ -75,6 +76,7 @@ def test_to_dict_with_parameters(self, monkeypatch): "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {}, + "tools": None, }, } @@ -90,6 +92,7 @@ def test_from_dict(self, monkeypatch): "some_test_param": "test-params", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "tools": None, }, } component: CohereGenerator = CohereGenerator.from_dict(data) From a9179bd731b474ef81ac9afec9d237debc9a45bc Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 28 Jan 2025 12:35:06 +0000 Subject: [PATCH 222/229] Update the changelog --- integrations/cohere/CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/integrations/cohere/CHANGELOG.md b/integrations/cohere/CHANGELOG.md index 1300a3efa..42cdaaf8b 100644 --- a/integrations/cohere/CHANGELOG.md +++ b/integrations/cohere/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [integrations/cohere-v3.0.0] - 2025-01-28 + +### 🚀 Features + +- CohereChatGenerator - add tools support (#1318) + +### 🚜 Refactor + +- Migrate Cohere to V2 (#1321) + + ## [integrations/cohere-v2.0.2] - 2025-01-15 ### 🐛 Bug Fixes From e4f462f99edeede7505b9ddc441c8629009d3a9f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 29 Jan 2025 10:10:26 +0100 Subject: [PATCH 223/229] Automatically open PR for subproject CHANGELOG.md updates (#1328) --- .github/workflows/CI_pypi_release.yml | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/.github/workflows/CI_pypi_release.yml b/.github/workflows/CI_pypi_release.yml index 1162ca3b3..4f4080619 100644 --- a/.github/workflows/CI_pypi_release.yml +++ b/.github/workflows/CI_pypi_release.yml @@ -57,11 +57,21 @@ jobs: --include-path "${{ steps.pathfinder.outputs.project_path }}/**/*" --tag-pattern "${{ steps.pathfinder.outputs.project_path }}-v*" - - name: Commit changelog - uses: EndBug/add-and-commit@v9 + - name: Create Pull Request + uses: peter-evans/create-pull-request@v7 with: - author_name: "HaystackBot" - author_email: "accounts@deepset.ai" - message: "Update the changelog" - add: ${{ steps.pathfinder.outputs.project_path }} - push: origin HEAD:main + token: ${{ secrets.HAYSTACK_BOT_TOKEN }} + commit-message: "Update changelog for ${{ steps.pathfinder.outputs.project_path }}" + branch: update-changelog-${{ steps.pathfinder.outputs.project_path }} + title: "docs: update changelog for ${{ 
steps.pathfinder.outputs.project_path }}" + add-paths: | + ${{ steps.pathfinder.outputs.project_path }}/CHANGELOG.md + body: | + This PR updates the changelog for ${{ steps.pathfinder.outputs.project_path }} integration + with the latest changes just released on PyPi. Please review the changelog diff below and adjust it + if necessary. + + A good changelog diff simply lists these latest changes on top of the CHANGELOG.md file. + If there are some diffs that seem out of place, please adjust the CHANGELOG.md file manually. + Either way, please merge this PR as soon as possible. + From 97742c1df41f3ea7c2748ada94a18c98ef452d68 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 29 Jan 2025 14:44:51 +0100 Subject: [PATCH 224/229] feat!: Google Vertex - support for Tool + general refactoring (#1327) * progress * more progress * fixes + new tests * improvements * static methods * fix test * fixes for multiple tools + streaming w tools test --- integrations/google_vertex/pyproject.toml | 1 + .../generators/google_vertex/chat/gemini.py | 355 ++++++--- .../google_vertex/tests/chat/test_gemini.py | 749 +++++++++++------- 3 files changed, 720 insertions(+), 385 deletions(-) diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index d8b7b3408..3d88238e7 100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", + "jsonschema", # needed for Tool ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 516116321..c8c9f22a0 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -5,25 +5,74 @@ from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk -from haystack.dataclasses.byte_stream import ByteStream -from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils import deserialize_callable, serialize_callable from vertexai import init as vertexai_init from vertexai.generative_models import ( Content, + FunctionDeclaration, GenerationConfig, GenerationResponse, GenerativeModel, HarmBlockThreshold, HarmCategory, Part, - Tool, ToolConfig, ) +from vertexai.generative_models import Tool as VertexTool logger = logging.getLogger(__name__) +def _convert_chatmessage_to_google_content(message: ChatMessage) -> Content: + """ + Converts a Haystack `ChatMessage` to a Google `Content` object. + System messages are not supported. + + :param message: The Haystack `ChatMessage` to convert. + :returns: The Google `Content` object. + """ + + if message.is_from(ChatRole.SYSTEM): + msg = "This function does not support system messages." 
+ raise ValueError(msg) + + texts = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not texts and not tool_calls and not tool_call_results: + msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + raise ValueError(msg) + + if len(texts) + len(tool_call_results) > 1: + msg = "A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`." + raise ValueError(msg) + + role = "model" if message.is_from(ChatRole.ASSISTANT) else "user" + + if tool_call_results: + part = Part.from_function_response( + name=tool_call_results[0].origin.tool_name, response={"result": tool_call_results[0].result} + ) + return Content(parts=[part], role=role) + + parts = [Part.from_text(texts[0])] if texts else [] + for tc in tool_calls: + part = Part.from_dict( + { + "function_call": { + "name": tc.tool_name, + "args": tc.arguments, + } + } + ) + parts.append(part) + + return Content(parts=parts, role=role) + + @component class VertexAIGeminiChatGenerator: """ @@ -32,7 +81,7 @@ class VertexAIGeminiChatGenerator: Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - Usage example: + ### Usage example ```python from haystack.dataclasses import ChatMessage from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator @@ -44,6 +93,43 @@ class VertexAIGeminiChatGenerator: print(res["replies"][0].text) >>> The Shawshank Redemption + + #### With Tool calling: + + ```python + from typing import Annotated + from haystack.utils import Secret + from haystack.dataclasses.chat_message import ChatMessage + from haystack.components.tools import ToolInvoker + from haystack.tools import create_tool_from_function + + from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator + + # example function to get the current weather + def get_current_weather( + location: Annotated[str, "The city for which to get the weather, e.g. 'San Francisco'"] = "Munich", + unit: Annotated[str, "The unit for the temperature, e.g. 'celsius'"] = "celsius", + ) -> str: + return f"The weather in {location} is sunny. The temperature is 20 {unit}." + + tool = create_tool_from_function(get_current_weather) + tool_invoker = ToolInvoker(tools=[tool]) + + gemini_chat = VertexAIGeminiChatGenerator( + model="gemini-2.0-flash-exp", + tools=[tool], + ) + user_message = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")] + replies = gemini_chat.run(messages=user_message)["replies"] + print(replies[0].tool_calls) + + # actually invoke the tool + tool_messages = tool_invoker.run(messages=replies)["tool_messages"] + messages = user_message + replies + tool_messages + + # transform the tool call result into a human readable message + final_replies = gemini_chat.run(messages=messages)["replies"] + print(final_replies[0].text) ``` """ @@ -57,7 +143,6 @@ def __init__( safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, tool_config: Optional[ToolConfig] = None, - system_instruction: Optional[Union[str, ByteStream, Part]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -66,8 +151,8 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). 
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. :param generation_config: Configuration for the generation process. @@ -77,12 +162,10 @@ def __init__( for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold) and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory) for more details. - :param tools: List of tools to use when generating content. See the documentation for - [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) - the list of supported arguments. + :param tools: + A list of tools for which the model can prepare calls. :param tool_config: The tool config to use. See the documentation for [ToolConfig] (https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig) - :param system_instruction: Default system instruction to use for generating content. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. @@ -91,6 +174,8 @@ def __init__( # Login to GCP. 
This will fail if user has not set up their gcloud SDK vertexai_init(project=project_id, location=location) + _check_duplicate_tool_names(tools) + self._model_name = model self._project_id = project_id self._location = location @@ -100,32 +185,22 @@ def __init__( self._safety_settings = safety_settings self._tools = tools self._tool_config = tool_config - self._system_instruction = system_instruction self._streaming_callback = streaming_callback - # except streaming_callback, all other model parameters can be passed during initialization self._model = GenerativeModel( self._model_name, - generation_config=self._generation_config, - safety_settings=self._safety_settings, - tools=self._tools, tool_config=self._tool_config, - system_instruction=self._system_instruction, ) - def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: + @staticmethod + def _generation_config_to_dict(config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: + """Converts the GenerationConfig object to a dictionary.""" if isinstance(config, dict): return config - return { - "temperature": config._raw_generation_config.temperature, - "top_p": config._raw_generation_config.top_p, - "top_k": config._raw_generation_config.top_k, - "candidate_count": config._raw_generation_config.candidate_count, - "max_output_tokens": config._raw_generation_config.max_output_tokens, - "stop_sequences": config._raw_generation_config.stop_sequences, - } + return config.to_dict() - def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]: + @staticmethod + def _tool_config_to_dict(tool_config: ToolConfig) -> Dict[str, Any]: """Serializes the ToolConfig object into a dictionary.""" mode = tool_config._gapic_tool_config.function_calling_config.mode allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names @@ -152,13 +227,10 @@ def to_dict(self) -> Dict[str, Any]: location=self._location, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, + tools=[tool.to_dict() for tool in self._tools] if self._tools else None, tool_config=self._tool_config, - system_instruction=self._system_instruction, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] if (tool_config := data["init_parameters"].get("tool_config")) is not None: data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config) if (generation_config := data["init_parameters"].get("generation_config")) is not None: @@ -186,8 +258,7 @@ def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: ) ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] + deserialize_tools_inplace(data["init_parameters"], key="tools") if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) if (tool_config := data["init_parameters"].get("tool_config")) is not None: @@ -196,125 +267,175 @@ def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) - def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: - if isinstance(part, str): - return 
Part.from_text(part) - elif isinstance(part, ByteStream): - return Part.from_data(part.data, part.mime_type) - elif isinstance(part, Part): - return part - else: - msg = f"Unsupported type {type(part)} for part {part}" - raise ValueError(msg) + @staticmethod + def _convert_to_vertex_tools(tools: List[Tool]) -> List[VertexTool]: + """ + Converts a list of Haystack `Tool` to a list of Vertex `Tool` objects. - def _message_to_part(self, message: ChatMessage) -> Part: - if message.role == ChatRole.ASSISTANT and message.name: - p = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) - for k, v in json.loads(message.text).items(): - p.function_call.args[k] = v - return p - elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): - return Part.from_text(message.text) - elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): - return Part.from_function_response(name=message.name, response=message.text) - elif message.is_from(ChatRole.USER): - return self._convert_part(message.text) - - def _message_to_content(self, message: ChatMessage) -> Content: - if message.is_from(ChatRole.ASSISTANT) and message.name: - part = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) - for k, v in json.loads(message.text).items(): - part.function_call.args[k] = v - elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): - part = Part.from_text(message.text) - elif "FUNCTION" in ChatRole._member_names_ and message.is_from(ChatRole.FUNCTION): - part = Part.from_function_response(name=message.name, response=message.text) - elif message.is_from(ChatRole.USER): - part = self._convert_part(message.text) - else: - msg = f"Unsupported message role {message.role}" - raise ValueError(msg) + :param tools: The list of Haystack `Tool` to convert. + :returns: The list of Vertex `Tool` objects. + """ + function_declarations = [] - role = "model" if message.is_from(ChatRole.ASSISTANT) or message.is_from(ChatRole.SYSTEM) else "user" - return Content(parts=[part], role=role) + for tool in tools: + parameters = tool.parameters.copy() + + # Remove default values as Google API doesn't support them + for prop in parameters["properties"].values(): + prop.pop("default", None) + + function_declarations.append( + FunctionDeclaration(name=tool.name, description=tool.description, parameters=parameters) + ) + return [VertexTool(function_declarations=function_declarations)] @component.output_types(replies=List[ChatMessage]) def run( self, messages: List[ChatMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + *, + tools: Optional[List[Tool]] = None, ): - """Prompts Google Vertex AI Gemini model to generate a response to a list of messages. - - :param messages: The last message is the prompt, the rest are the history. - :param streaming_callback: A callback function that is called when a new token is received from the stream. - :returns: A dictionary with the following keys: - - `replies`: A list of ChatMessage objects representing the model's replies. """ - # check if streaming_callback is passed + :param messages: + A list of `ChatMessage` instances, representing the input messages. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. 
+ :returns: + A dictionary containing the following key: + - `replies`: A list containing the generated responses as `ChatMessage` instances. + """ streaming_callback = streaming_callback or self._streaming_callback - history = [self._message_to_content(m) for m in messages[:-1]] - session = self._model.start_chat(history=history) + tools = tools or self._tools + _check_duplicate_tool_names(tools) + google_tools = self._convert_to_vertex_tools(tools) if tools else None + + if messages[0].is_from(ChatRole.SYSTEM): + self._model._system_instruction = Part.from_text(messages[0].text) + messages = messages[1:] + + google_messages = [_convert_chatmessage_to_google_content(m) for m in messages] + + session = self._model.start_chat(history=google_messages[:-1]) + + candidate_count = 1 + if self._generation_config: + config_dict = self._generation_config_to_dict(self._generation_config) + candidate_count = config_dict.get("candidate_count", 1) + + if streaming_callback and candidate_count > 1: + msg = "Streaming is not supported with multiple candidates. Set candidate_count to 1." + raise ValueError(msg) - new_message = self._message_to_part(messages[-1]) res = session.send_message( - content=new_message, + content=google_messages[-1], + generation_config=self._generation_config, + safety_settings=self._safety_settings, stream=streaming_callback is not None, + tools=google_tools, ) - replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + replies = ( + self._stream_response_and_convert_to_messages(res, streaming_callback) + if streaming_callback + else self._convert_response_to_messages(res) + ) return {"replies": replies} - def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: + @staticmethod + def _convert_response_to_messages(response_body: GenerationResponse) -> List[ChatMessage]: """ - Extracts the responses from the Vertex AI response. + Converts the Google Vertex AI response to a list of `ChatMessage` instances. - :param response_body: The response from Vertex AI request. - :returns: The extracted responses. + :param response_body: The response from Google AI request. + :returns: List of `ChatMessage` instances. 
""" replies: List[ChatMessage] = [] + + usage_metadata = response_body.usage_metadata + openai_usage = { + "prompt_tokens": usage_metadata.prompt_token_count or 0, + "completion_tokens": usage_metadata.candidates_token_count or 0, + "total_tokens": usage_metadata.total_token_count or 0, + } + for candidate in response_body.candidates: - metadata = candidate.to_dict() + candidate_metadata = candidate.to_dict() + candidate_metadata.pop("content", None) + candidate_metadata["usage"] = openai_usage + + text = "" + tool_calls = [] for part in candidate.content.parts: - # Remove content from metadata - metadata.pop("content", None) - if part._raw_part.text != "": - replies.append(ChatMessage.from_assistant(part._raw_part.text, meta=metadata)) - elif part.function_call: - metadata["function_call"] = part.function_call - new_message = ChatMessage.from_assistant(json.dumps(dict(part.function_call.args)), meta=metadata) - new_message.name = part.function_call.name - replies.append(new_message) + # we need this strange check: calling part.text directly raises an error if the part has no text + if "text" in part._raw_part: + text += part.text + elif "function_call" in part._raw_part: + tool_calls.append( + ToolCall( + tool_name=part.function_call.name, + arguments=dict(part.function_call.args), + ) + ) + reply = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=candidate_metadata) + replies.append(reply) return replies - def _get_stream_response( + def _stream_response_and_convert_to_messages( self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None] ) -> List[ChatMessage]: """ - Extracts the responses from the Vertex AI streaming response. + Streams the Google Vertex AI response and converts it to a list of `ChatMessage` instances. - :param stream: The streaming response from the Vertex AI request. + :param stream: The streaming response from the Google AI request. :param streaming_callback: The handler for the streaming response. - :returns: The extracted response with the content of all streaming chunks. + :returns: List of `ChatMessage` instances. 
""" - replies: List[ChatMessage] = [] + + text = "" + tool_calls = [] + chunk_dict = {} for chunk in stream: - content: Union[str, Dict[str, Any]] = "" - metadata = chunk.to_dict() # we store whole chunk as metadata for streaming - for candidate in chunk.candidates: - for part in candidate.content.parts: - if part._raw_part.text: - content = chunk.text - replies.append(ChatMessage.from_assistant(content, meta=metadata)) - elif part.function_call: - metadata["function_call"] = part.function_call - content = json.dumps(dict(part.function_call.args)) - new_message = ChatMessage.from_assistant(content, meta=metadata) - new_message.name = part.function_call.name - replies.append(new_message) - streaming_callback(StreamingChunk(content=content, meta=metadata)) + content_to_stream = "" + chunk_dict = chunk.to_dict() + + # Only one candidate is supported with streaming + candidate = chunk_dict["candidates"][0] + + for part in candidate["content"]["parts"]: + if new_text := part.get("text"): + content_to_stream += new_text + text += new_text + elif new_function_call := part.get("function_call"): + content_to_stream += json.dumps(dict(new_function_call)) + tool_calls.append( + ToolCall( + tool_name=new_function_call["name"], + arguments=dict(new_function_call["args"]), + ) + ) + + streaming_callback(StreamingChunk(content=content_to_stream, meta=chunk_dict)) + + # store the last chunk metadata + meta = chunk_dict + + # format the usage metadata to be compatible with OpenAI + usage_metadata = meta.pop("usage_metadata", {}) + + openai_usage = { + "prompt_tokens": usage_metadata.get("prompt_token_count", 0), + "completion_tokens": usage_metadata.get("candidates_token_count", 0), + "total_tokens": usage_metadata.get("total_token_count", 0), + } - return replies + meta["usage"] = openai_usage + + return [ChatMessage.from_assistant(text=text or None, meta=meta, tool_calls=tool_calls)] diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 614b83909..96f2bc9ff 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -1,262 +1,201 @@ +import json +from typing import Annotated, Literal from unittest.mock import MagicMock, Mock, patch import pytest from haystack import Pipeline -from haystack.components.builders import ChatPromptBuilder -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, TextContent, ToolCall +from haystack.tools import Tool, create_tool_from_function from vertexai.generative_models import ( Content, - FunctionDeclaration, GenerationConfig, GenerationResponse, HarmBlockThreshold, HarmCategory, Part, - Tool, ToolConfig, ) -from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator - -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, +from haystack_integrations.components.generators.google_vertex.chat.gemini import ( + VertexAIGeminiChatGenerator, + _convert_chatmessage_to_google_content, ) +def get_current_weather( + city: Annotated[str, "the city for which to get the weather, e.g. 'San Francisco'"] = "Munich", + unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius", +): + """A simple function to get the current weather for a location.""" + return f"Weather report for {city}: 20 {unit}, sunny" + + @pytest.fixture -def chat_messages(): - return [ - ChatMessage.from_system("You are a helpful assistant"), - ChatMessage.from_user("What's the capital of France"), - ] - - -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") -def test_init(mock_vertexai_init, _mock_generative_model): - - generation_config = GenerationConfig( - candidate_count=1, - stop_sequences=["stop"], - max_output_tokens=10, - temperature=0.5, - top_p=0.5, - top_k=0.5, - ) - safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} +def tools(): + tool = create_tool_from_function(get_current_weather) + return [tool] - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiChatGenerator( - project_id="TestID123", - location="TestLocation", - generation_config=generation_config, - safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, - system_instruction="Please provide brief answers.", +def test_convert_chatmessage_to_google_content(): + chat_message = ChatMessage.from_assistant("Hello, how are you?") + google_content = _convert_chatmessage_to_google_content(chat_message) + + assert google_content.parts[0].text == "Hello, how are you?" + assert google_content.role == "model" + + message = ChatMessage.from_user("I have a question") + google_content = _convert_chatmessage_to_google_content(message) + assert google_content.parts[0].text == "I have a question" + assert google_content.role == "user" + + message = ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] ) - mock_vertexai_init.assert_called() - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._generation_config == generation_config - assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] - assert gemini._tool_config == tool_config - assert gemini._system_instruction == "Please provide brief answers." 
- - -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") -def test_to_dict(_mock_vertexai_init, _mock_generative_model): - - gemini = VertexAIGeminiChatGenerator() - assert gemini.to_dict() == { - "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", - "init_parameters": { - "model": "gemini-1.5-flash", - "project_id": None, - "location": None, - "generation_config": None, - "safety_settings": None, - "streaming_callback": None, - "tools": None, - "tool_config": None, - "system_instruction": None, - }, - } - - -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") -def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): - generation_config = GenerationConfig( - candidate_count=1, - stop_sequences=["stop"], - max_output_tokens=10, - temperature=0.5, - top_p=0.5, - top_k=2, + google_content = _convert_chatmessage_to_google_content(message) + assert google_content.parts[0].function_call.name == "weather" + assert google_content.parts[0].function_call.args == {"city": "Paris"} + assert google_content.role == "model" + + tool_result = json.dumps({"weather": "sunny", "temperature": "25"}) + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) ) - safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + google_content = _convert_chatmessage_to_google_content(message) + assert google_content.parts[0].function_response.name == "weather" + assert google_content.parts[0].function_response.response == {"result": tool_result} + assert google_content.role == "user" - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiChatGenerator( - project_id="TestID123", - location="TestLocation", - generation_config=generation_config, - safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, - system_instruction="Please provide brief answers.", +def test_convert_chatmessage_to_google_content_invalid(): + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + _convert_chatmessage_to_google_content(message) + + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")], ) + with pytest.raises(ValueError): + _convert_chatmessage_to_google_content(message) - assert gemini.to_dict() == { - "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", - "init_parameters": { - "model": "gemini-1.5-flash", - "project_id": "TestID123", - "location": "TestLocation", - "generation_config": { - "temperature": 0.5, - "top_p": 0.5, - "top_k": 2.0, - "candidate_count": 1, - "max_output_tokens": 10, - "stop_sequences": ["stop"], - }, - "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, - "streaming_callback": None, - "tools": [ - { - "function_declarations": [ - { - "name": 
"get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "OBJECT", - "properties": { - "location": { - "type": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": {"type": "STRING", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - "property_ordering": ["location", "unit"], - }, - } - ] - } - ], - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, - "system_instruction": "Please provide brief answers.", - }, - } + message = ChatMessage.from_system("You are a helpful assistant.") + with pytest.raises(ValueError): + _convert_chatmessage_to_google_content(message) + + +class TestVertexAIGeminiChatGenerator: + + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_init(self, mock_vertexai_init, _mock_generative_model, tools): + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool_config = ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=ToolConfig.FunctionCallingConfig.Mode.ANY, + allowed_function_names=["get_current_weather_func"], + ) + ) -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") -def test_from_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiChatGenerator.from_dict( - { + gemini = VertexAIGeminiChatGenerator( + project_id="TestID123", + location="TestLocation", + generation_config=generation_config, + safety_settings=safety_settings, + tools=tools, + tool_config=tool_config, + ) + mock_vertexai_init.assert_called() + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config == generation_config + assert gemini._safety_settings == safety_settings + assert gemini._tools == tools + assert gemini._tool_config == tool_config + + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_to_dict(self, _mock_vertexai_init, _mock_generative_model): + + gemini = VertexAIGeminiChatGenerator() + assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { - "project_id": None, "model": "gemini-1.5-flash", + "project_id": None, + "location": None, "generation_config": None, "safety_settings": None, - "tools": None, "streaming_callback": None, + "tools": None, + "tool_config": None, }, } - ) - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id is None - assert gemini._safety_settings is None - assert gemini._tools is None - assert gemini._tool_config is None - assert gemini._system_instruction is None - assert gemini._generation_config is None + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") + 
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_to_dict_with_params(self, _mock_vertexai_init, _mock_generative_model): + tools = [Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)] + + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=2, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + tool_config = ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=ToolConfig.FunctionCallingConfig.Mode.ANY, + allowed_function_names=["get_current_weather_func"], + ) + ) + + gemini = VertexAIGeminiChatGenerator( + project_id="TestID123", + location="TestLocation", + generation_config=generation_config, + safety_settings=safety_settings, + tools=tools, + tool_config=tool_config, + ) -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") -def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiChatGenerator.from_dict( - { + assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { + "model": "gemini-1.5-flash", "project_id": "TestID123", "location": "TestLocation", - "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, "top_p": 0.5, - "top_k": 0.5, + "top_k": 2.0, "candidate_count": 1, "max_output_tokens": 10, "stop_sequences": ["stop"], }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "streaming_callback": None, "tools": [ { - "function_declarations": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - } - ] + "type": "haystack.tools.tool.Tool", + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + }, } ], "tool_config": { @@ -265,83 +204,357 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "allowed_function_names": ["get_current_weather_func"], } }, - "system_instruction": "Please provide brief answers.", - "streaming_callback": None, }, } - ) - assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" - assert gemini._location == "TestLocation" - assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) - assert isinstance(gemini._tool_config, ToolConfig) - assert isinstance(gemini._generation_config, GenerationConfig) - assert gemini._system_instruction == "Please provide brief answers." 
- assert ( - gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY - ) + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_from_dict(self, _mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiChatGenerator.from_dict( + { + "type": ( + "haystack_integrations.components.generators.google_vertex.chat.gemini." + "VertexAIGeminiChatGenerator" + ), + "init_parameters": { + "project_id": None, + "model": "gemini-1.5-flash", + "generation_config": None, + "safety_settings": None, + "tools": None, + "streaming_callback": None, + }, + } + ) + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._project_id is None + assert gemini._safety_settings is None + assert gemini._tools is None + assert gemini._tool_config is None + assert gemini._generation_config is None + + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_from_dict_with_param(self, _mock_vertexai_init, _mock_generative_model): + tools = [Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print)] + + gemini = VertexAIGeminiChatGenerator.from_dict( + { + "type": ( + "haystack_integrations.components.generators.google_vertex.chat.gemini." + "VertexAIGeminiChatGenerator" + ), + "init_parameters": { + "project_id": "TestID123", + "location": "TestLocation", + "model": "gemini-1.5-flash", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH + }, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + }, + } + ], + "tool_config": { + "function_calling_config": { + "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, + "allowed_function_names": ["get_current_weather_func"], + } + }, + "streaming_callback": None, + }, + } + ) -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") -def test_run(mock_generative_model): - mock_model = Mock() - mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model")) - mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate]) + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" + assert gemini._safety_settings == { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH + } + assert gemini._tools == tools + assert isinstance(gemini._tool_config, ToolConfig) + assert isinstance(gemini._generation_config, GenerationConfig) + assert ( + gemini._tool_config._gapic_tool_config.function_calling_config.mode + == ToolConfig.FunctionCallingConfig.Mode.ANY + ) - mock_model.send_message.return_value = mock_response - mock_model.start_chat.return_value = mock_model - mock_generative_model.return_value = mock_model + def test_convert_to_vertex_tools(self, tools): + vertex_tools = 
VertexAIGeminiChatGenerator._convert_to_vertex_tools(tools) - messages = [ - ChatMessage.from_system("You are a helpful assistant"), - ChatMessage.from_user("What's the capital of France?"), - ] - gemini = VertexAIGeminiChatGenerator() - response = gemini.run(messages=messages) + function_declaration = vertex_tools[0]._raw_tool.function_declarations[0] + assert function_declaration.name == tools[0].name + assert function_declaration.description == tools[0].description - mock_model.send_message.assert_called_once() - assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + assert function_declaration.parameters + # check if default values are removed. This type is not easily inspectable + assert "default" not in str(function_declaration.parameters) -@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") -def test_run_with_streaming_callback(mock_generative_model): - mock_model = Mock() - mock_responses = iter( - [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")] - ) - mock_model.send_message.return_value = mock_responses - mock_model.start_chat.return_value = mock_model - mock_generative_model.return_value = mock_model + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_run(self, mock_generative_model): + mock_model = Mock() + mock_candidate = MagicMock( + content=Content(parts=[Part.from_text("This is a generated response.")], role="model") + ) + mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate]) + + mock_model.send_message.return_value = mock_response + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model + + messages = [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France?"), + ] + gemini = VertexAIGeminiChatGenerator() + response = gemini.run(messages=messages) + + mock_model.send_message.assert_called_once() + assert "replies" in response + reply = response["replies"][0] + assert reply.role == ChatRole.ASSISTANT + assert reply.text == "This is a generated response." 
+ + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_run_with_streaming_callback(self, mock_generative_model): + mock_model = Mock() + + mock_responses = [ + MagicMock( + spec=GenerationResponse, + to_dict=lambda: { + "candidates": [{"content": {"parts": [{"text": "First part "}]}}], + }, + ), + MagicMock( + spec=GenerationResponse, + to_dict=lambda: { + "candidates": [{"content": {"parts": [{"text": "Second part"}]}}], + "usage_metadata": {"prompt_token_count": 10, "candidates_token_count": 5, "total_token_count": 15}, + }, + ), + ] + + mock_model.send_message.return_value = iter(mock_responses) + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model + + received_chunks = [] + + def streaming_callback(chunk: StreamingChunk) -> None: + received_chunks.append(chunk) + + gemini = VertexAIGeminiChatGenerator(streaming_callback=streaming_callback) + messages = [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France?"), + ] + + response = gemini.run(messages=messages) + + assert len(received_chunks) == 2 + assert received_chunks[0].content == "First part " + assert received_chunks[1].content == "Second part" + + assert "replies" in response + reply = response["replies"][0] + assert reply.role == ChatRole.ASSISTANT + assert reply.text == "First part Second part" + + assert reply.meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_run_with_tools(self, mock_generative_model, tools): + mock_model = Mock() + mock_candidate = MagicMock( + content=Content( + parts=[ + Part.from_dict( + {"function_call": {"name": "get_current_weather", "args": {"city": "Paris", "unit": "Celsius"}}} + ), + ], + role="model", + ) + ) + mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate]) + + mock_model.send_message.return_value = mock_response + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model + + messages = [ + ChatMessage.from_user("What's the weather in Paris?"), + ] + + gemini = VertexAIGeminiChatGenerator(tools=tools) + response = gemini.run(messages=messages) + + mock_model.send_message.assert_called_once() + call_kwargs = mock_model.send_message.call_args.kwargs + assert "tools" in call_kwargs + + assert "replies" in response + reply = response["replies"][0] + assert reply.role == ChatRole.ASSISTANT + assert not reply.text + assert len(reply.tool_calls) == 1 + assert reply.tool_calls[0].tool_name == "get_current_weather" + assert reply.tool_calls[0].arguments == {"city": "Paris", "unit": "Celsius"} + + @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") + def test_run_with_muliple_tools_and_streaming(self, mock_generative_model, tools): + """ + Test that the generator can handle multiple tools and streaming. + Note: this test case is made up because in practice I have always seen multiple function calls in a single + streaming chunk. + """ + + def population(city: Annotated[str, "the city for which to get the population, e.g. 
'Munich'"] = "Munich"): + """A simple function to get the population for a location.""" + return f"Population of {city}: 1,000,000" + + multiple_tools = [tools[0], create_tool_from_function(population)] + + mock_model = Mock() + + mock_responses = [ + MagicMock( + spec=GenerationResponse, + to_dict=lambda: { + "candidates": [ + { + "content": { + "parts": [ + { + "function_call": { + "name": "get_current_weather", + "args": {"city": "Munich", "unit": "Farenheit"}, + } + } + ] + } + } + ] + }, + ), + MagicMock( + spec=GenerationResponse, + to_dict=lambda: { + "candidates": [ + {"content": {"parts": [{"function_call": {"name": "population", "args": {"city": "Munich"}}}]}} + ], + "usage_metadata": {"prompt_token_count": 10, "candidates_token_count": 5, "total_token_count": 15}, + }, + ), + ] - streaming_callback_called = [] + mock_model.send_message.return_value = iter(mock_responses) + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model - def streaming_callback(_chunk: StreamingChunk) -> None: - nonlocal streaming_callback_called - streaming_callback_called = True + received_chunks = [] - gemini = VertexAIGeminiChatGenerator(streaming_callback=streaming_callback) - messages = [ - ChatMessage.from_system("You are a helpful assistant"), - ChatMessage.from_user("What's the capital of France?"), - ] - response = gemini.run(messages=messages) - mock_model.send_message.assert_called_once() - assert "replies" in response + def streaming_callback(chunk: StreamingChunk) -> None: + received_chunks.append(chunk) + messages = [ + ChatMessage.from_user("What's the weather in Munich (in Farenheit) and how many people live there?"), + ] -def test_serialization_deserialization_pipeline(): + gemini = VertexAIGeminiChatGenerator(tools=multiple_tools, streaming_callback=streaming_callback) + response = gemini.run(messages=messages) - pipeline = Pipeline() - template = [ChatMessage.from_user("Translate to {{ target_language }}. 
Context: {{ snippet }}; Translation:")] - pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template)) - pipeline.add_component("gemini", VertexAIGeminiChatGenerator(project_id="TestID123")) - pipeline.connect("prompt_builder.prompt", "gemini.messages") + assert len(received_chunks) == 2 + assert json.loads(received_chunks[0].content) == { + "name": "get_current_weather", + "args": {"city": "Munich", "unit": "Farenheit"}, + } + assert json.loads(received_chunks[1].content) == {"name": "population", "args": {"city": "Munich"}} + + assert "replies" in response + reply = response["replies"][0] + assert reply.role == ChatRole.ASSISTANT + assert not reply.text + assert len(reply.tool_calls) == 2 + assert reply.tool_calls[0].tool_name == "get_current_weather" + assert reply.tool_calls[0].arguments == {"city": "Munich", "unit": "Farenheit"} + assert reply.tool_calls[1].tool_name == "population" + assert reply.tool_calls[1].arguments == {"city": "Munich"} + assert reply.meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + + def test_serde_in_pipeline(self): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = VertexAIGeminiChatGenerator( + project_id="TestID123", + model="gemini-1.5-flash", + generation_config=GenerationConfig( + temperature=0.6, + stop_sequences=["stop", "words"], + ), + tools=[tool], + ) + + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": ( + "haystack_integrations.components.generators.google_vertex.chat.gemini." + "VertexAIGeminiChatGenerator" + ), + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-1.5-flash", + "generation_config": { + "temperature": 0.6, + "stop_sequences": ["stop", "words"], + }, + "location": None, + "safety_settings": None, + "streaming_callback": None, + "tool_config": None, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "name", + "description": "description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + }, + } + ], + }, + } + }, + "connections": [], + } - pipeline_dict = pipeline.to_dict() + pipeline_yaml = pipeline.dumps() - new_pipeline = Pipeline.from_dict(pipeline_dict) - assert new_pipeline == pipeline + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline From ae18ca60d58ced03b9f0b36a0450cbca287588fb Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 29 Jan 2025 15:03:15 +0100 Subject: [PATCH 225/229] Fix failing automatic PR generation for changelog (#1330) --- .github/workflows/CI_pypi_release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI_pypi_release.yml b/.github/workflows/CI_pypi_release.yml index 4f4080619..f4ebba9dd 100644 --- a/.github/workflows/CI_pypi_release.yml +++ b/.github/workflows/CI_pypi_release.yml @@ -63,6 +63,7 @@ jobs: token: ${{ secrets.HAYSTACK_BOT_TOKEN }} commit-message: "Update changelog for ${{ steps.pathfinder.outputs.project_path }}" branch: update-changelog-${{ steps.pathfinder.outputs.project_path }} + base: main title: "docs: update changelog for ${{ steps.pathfinder.outputs.project_path }}" add-paths: | ${{ steps.pathfinder.outputs.project_path }}/CHANGELOG.md From 232c537c60ec67616f230cb1e57bd110b1400ede Mon Sep 17 00:00:00 2001 From: Haystack Bot 
<73523382+HaystackBot@users.noreply.github.com> Date: Wed, 29 Jan 2025 15:07:33 +0100 Subject: [PATCH 226/229] Update changelog for integrations/google_vertex (#1331) Co-authored-by: anakin87 <44616784+anakin87@users.noreply.github.com> --- integrations/google_vertex/CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index a544507a1..9cb3e4575 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,6 +1,10 @@ # Changelog -## [unreleased] +## [integrations/google_vertex-v5.0.0] - 2025-01-29 + +### 🚀 Features + +- [**breaking**] Google Vertex - support for Tool + general refactoring (#1327) ### 🌀 Miscellaneous From f08c2640045717db76536206e653d99fd8d02d38 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 29 Jan 2025 15:27:09 +0100 Subject: [PATCH 227/229] feat: add `response_format` param to `OllamaChatGenerator` (#1326) * Add response_ format param to Ollama integration * Add related tests --- integrations/ollama/pyproject.toml | 4 +- .../generators/ollama/chat/chat_generator.py | 21 +++++- .../ollama/tests/test_chat_generator.py | 71 +++++++++++++++++++ 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index 65895e636..8c02f360b 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "ollama>=0.4.0"] +dependencies = ["haystack-ai", "ollama>=0.4.0", "pydantic"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme" @@ -165,5 +165,5 @@ markers = [ addopts = ["--import-mode=importlib"] [[tool.mypy.overrides]] -module = ["haystack.*", "haystack_integrations.*", "pytest.*", "ollama.*"] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "ollama.*", "pydantic.*"] ignore_missing_imports = true diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index c2112d3d6..fd4d6e6af 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from pydantic.json_schema import JsonSchemaValue from ollama import ChatResponse, Client @@ -97,6 +98,7 @@ def __init__( keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, tools: Optional[List[Tool]] = None, + response_format: Optional[Union[None, Literal["json"], JsonSchemaValue]] = None, ): """ :param model: @@ -124,6 +126,11 @@ def __init__( A list of tools for which the model can 
prepare calls. Not all models support tools. For a list of models compatible with tools, see the [models page](https://ollama.com/search?c=tools). + :param response_format: + The format for structured model outputs. The value can be: + - None: No specific structure or format is applied to the response. The response is returned as-is. + - "json": The response is formatted as a JSON object. + - JSON Schema: The response is formatted as a JSON object that adheres to the specified JSON Schema. """ _check_duplicate_tool_names(tools) @@ -135,7 +142,7 @@ def __init__( self.keep_alive = keep_alive self.streaming_callback = streaming_callback self.tools = tools - + self.response_format = response_format self._client = Client(host=self.url, timeout=self.timeout) def to_dict(self) -> Dict[str, Any]: @@ -156,6 +163,7 @@ def to_dict(self) -> Dict[str, Any]: timeout=self.timeout, streaming_callback=callback_name, tools=serialized_tools, + response_format=self.response_format, ) @classmethod @@ -237,6 +245,14 @@ def run( msg = "Ollama does not support tools and streaming at the same time. Please choose one." raise ValueError(msg) + if self.response_format and tools: + msg = "Ollama does not support tools and response_format at the same time. Please choose one." + raise ValueError(msg) + + if self.response_format and stream: + msg = "Ollama does not support streaming and response_format at the same time. Please choose one." + raise ValueError(msg) + ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] if tools else None ollama_messages = [_convert_chatmessage_to_ollama_format(msg) for msg in messages] @@ -247,6 +263,7 @@ def run( stream=stream, keep_alive=self.keep_alive, options=generation_kwargs, + format=self.response_format, ) if stream: diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index cb357027a..ea2c2035e 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -165,6 +165,7 @@ def test_init_default(self): assert component.streaming_callback is None assert component.tools is None assert component.keep_alive is None + assert component.response_format is None def test_init(self, tools): component = OllamaChatGenerator( @@ -175,6 +176,7 @@ def test_init(self, tools): keep_alive="10m", streaming_callback=print_streaming_chunk, tools=tools, + response_format={"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}}, ) assert component.model == "llama2" @@ -184,6 +186,10 @@ def test_init(self, tools): assert component.keep_alive == "10m" assert component.streaming_callback is print_streaming_chunk assert component.tools == tools + assert component.response_format == { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "number"}}, + } def test_init_fail_with_duplicate_tool_names(self, tools): @@ -206,6 +212,7 @@ def test_to_dict(self): generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, tools=[tool], keep_alive="5m", + response_format={"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}}, ) data = component.to_dict() assert data == { @@ -235,6 +242,10 @@ def test_to_dict(self): }, }, ], + "response_format": { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "number"}}, + }, }, } @@ -273,6 +284,10 @@ def test_from_dict(self): }, }, ], + "response_format": { + "type": "object", + "properties": {"name": 
{"type": "string"}, "age": {"type": "number"}}, + }, }, } component = OllamaChatGenerator.from_dict(data) @@ -286,6 +301,10 @@ def test_from_dict(self): } assert component.timeout == 120 assert component.tools == [tool] + assert component.response_format == { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "number"}}, + } @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client") def test_run(self, mock_client): @@ -319,6 +338,7 @@ def test_run(self, mock_client): tools=None, options={}, keep_alive=None, + format=None, ) assert "replies" in result @@ -456,3 +476,54 @@ def test_run_with_tools(self, tools): assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} + + @pytest.mark.integration + def test_run_with_response_format(self): + response_format = { + "type": "object", + "properties": {"capital": {"type": "string"}, "population": {"type": "number"}}, + } + chat_generator = OllamaChatGenerator(model="llama3.2:3b", response_format=response_format) + + message = ChatMessage.from_user("What's the capital of France and its population?") + response = chat_generator.run([message]) + + assert isinstance(response, dict) + assert isinstance(response["replies"], list) + + # Parse the response text as JSON and verify its structure + response_data = json.loads(response["replies"][0].text) + assert isinstance(response_data, dict) + assert "capital" in response_data + assert isinstance(response_data["capital"], str) + assert "population" in response_data + assert isinstance(response_data["population"], (int, float)) + assert response_data["capital"] == "Paris" + + def test_run_with_streaming_and_format(self): + response_format = { + "type": "object", + "properties": {"answer": {"type": "string"}}, + } + streaming_callback = Mock() + chat_generator = OllamaChatGenerator( + model="llama3.2:3b", streaming_callback=streaming_callback, response_format=response_format + ) + + chat_messages = [ + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), + ] + with pytest.raises(ValueError): + chat_generator.run([chat_messages]) + + def test_run_with_tools_and_format(self, tools): + response_format = { + "type": "object", + "properties": {"capital": {"type": "string"}, "population": {"type": "number"}}, + } + chat_generator = OllamaChatGenerator(model="llama3.2:3b", tools=tools, response_format=response_format) + message = ChatMessage.from_user("What's the weather in Paris?") + with pytest.raises(ValueError): + chat_generator.run([message]) From ea4f2d6e5b88194d0635104f6873956d8bcec303 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 29 Jan 2025 15:39:14 +0100 Subject: [PATCH 228/229] chore(deps): bump aws-actions/configure-aws-credentials (#1324) Bumps [aws-actions/configure-aws-credentials](https://github.com/aws-actions/configure-aws-credentials) from 4.0.2 to 4.0.3. 
- [Release notes](https://github.com/aws-actions/configure-aws-credentials/releases) - [Changelog](https://github.com/aws-actions/configure-aws-credentials/blob/main/CHANGELOG.md) - [Commits](https://github.com/aws-actions/configure-aws-credentials/compare/e3dd6a429d7300a6a4c196c26e071d42e0343502...4fc4975a852c8cd99761e2de1f4ba73402e44dd9) --- updated-dependencies: - dependency-name: aws-actions/configure-aws-credentials dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Amna Mubashar --- .github/workflows/amazon_bedrock.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/amazon_bedrock.yml b/.github/workflows/amazon_bedrock.yml index be356ee4c..c9b3563bb 100644 --- a/.github/workflows/amazon_bedrock.yml +++ b/.github/workflows/amazon_bedrock.yml @@ -68,7 +68,7 @@ jobs: # Do not authenticate on pull requests from forks - name: AWS authentication if: github.event.pull_request.head.repo.full_name == github.repository - uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 + uses: aws-actions/configure-aws-credentials@4fc4975a852c8cd99761e2de1f4ba73402e44dd9 with: aws-region: ${{ env.AWS_REGION }} role-to-assume: ${{ secrets.AWS_CI_ROLE_ARN }} From 4b003f2a4cf75d85fca6436fb91a4fe6d46e6006 Mon Sep 17 00:00:00 2001 From: Haystack Bot <73523382+HaystackBot@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:19:36 +0100 Subject: [PATCH 229/229] Update changelog for integrations/ollama (#1332) Co-authored-by: Amnah199 <13835656+Amnah199@users.noreply.github.com> --- integrations/ollama/CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index e4e8f3602..7415f1e99 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/ollama-v2.3.0] - 2025-01-29 + +### 🚀 Features + +- Add `response_format` param to `OllamaChatGenerator` (#1326) + + ## [integrations/ollama-v2.2.0] - 2025-01-16 ### 🚀 Features