From 1d4e95a40c9b16affda0dfb6d180e2dc3534a8e1 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 16 Feb 2024 13:50:22 +0100 Subject: [PATCH 01/20] Starting to add TransformersTextRouter --- .../routers/zero_shot_text_router.py | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 haystack/components/routers/zero_shot_text_router.py diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py new file mode 100644 index 0000000000..cdd751b24a --- /dev/null +++ b/haystack/components/routers/zero_shot_text_router.py @@ -0,0 +1,113 @@ +import logging +from pathlib import Path +from typing import Any, List, Dict, Optional, Union + +from haystack import component +from haystack.lazy_imports import LazyImport +from haystack.utils import ComponentDevice +from haystack.utils import Secret, deserialize_secrets_inplace + +logger = logging.getLogger(__name__) + +SUPPORTED_TASKS = ["zero-shot-classification"] + +with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: + from huggingface_hub import model_info + from transformers import pipeline + + +@component +class TransformersTextRouter: + """ + Routes a text input onto different output connections depending on which label it has been categorized into. + This is useful for routing queries to different models in a pipeline depending on their categorization. + The set of labels to be used for categorization can be specified. + + Example usage in a retrieval pipeline that passes question-like queries to an embedding retriever and keyword-like + queries to a BM25 retriever: + + ```python + document_store = InMemoryDocumentStore() + p = Pipeline() + p.add_component(instance=TransformersTextRouter(), name="text_router") + p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever") + p.connect("text_router.en", "retriever.query") + p.run({"text_router": {"text": "What's your query?"}}) + ``` + """ + + def __init__( + self, + labels: List[str], + model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", + device: Optional[ComponentDevice] = None, + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), + pipeline_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + :param labels: + :param pipeline_kwargs: Dictionary containing keyword arguments used to initialize the + Hugging Face pipeline for zero shot text classification. + """ + torch_and_transformers_import.check() + self.token = token + self.labels = labels + component.set_output_types(self, **{label: str for label in labels}) + + token = token.resolve_value() if token else None + + # check if the pipeline_kwargs contain the essential parameters + # otherwise, populate them with values from other init parameters + pipeline_kwargs.setdefault("model", model) + pipeline_kwargs.setdefault("token", token) + + device = ComponentDevice.resolve_device(device) + device.update_hf_kwargs(pipeline_kwargs, overwrite=False) + + # task identification and validation + task = "zero-shot-classification" + if task is None: + if "task" in pipeline_kwargs: + task = pipeline_kwargs["task"] + elif isinstance(pipeline_kwargs["model"], str): + task = model_info(pipeline_kwargs["model"], token=pipeline_kwargs["token"]).pipeline_tag + + if task not in SUPPORTED_TASKS: + raise ValueError( + f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(SUPPORTED_TASKS)}." + ) + pipeline_kwargs["task"] = task + + self.pipeline_kwargs = pipeline_kwargs + self.pipeline = None + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + if isinstance(self.pipeline_kwargs["model"], str): + return {"model": self.pipeline_kwargs["model"]} + return {"model": f"[object of type {type(self.pipeline_kwargs['model'])}]"} + + def warm_up(self): + if self.pipeline is None: + self.pipeline = pipeline(**self.pipeline_kwargs) + + def run(self, text: str) -> Dict[str, str]: + """ + Run the TransformersTextRouter. This method routes the text to one of the different edges based on which label + it has been categorized into. + + :param text: A str to route to one of the different edges. + """ + if self.pipeline is None: + raise RuntimeError( + "The zero-shot classification pipeline has not been loaded. Please call warm_up() before running." + ) + + if not isinstance(text, str): + raise TypeError("TransformersTextRouter expects a str as input.") + + prediction = self.pipeline(sequences=[text], candidate_labels=self.labels, multi_label=self.multi_label) + label = prediction[0]["labels"][0] + return {label: text} From fafc322ff6d3196bf531f42be65deda23c1a02d2 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 16 Feb 2024 14:43:42 +0100 Subject: [PATCH 02/20] First pass at a TextRouter based off of the zero shot classification model on HuggingFace --- .../generators/hugging_face_local.py | 39 ++++-------- .../routers/zero_shot_text_router.py | 62 +++++++++++-------- haystack/utils/hf.py | 45 ++++++++++++++ 3 files changed, 91 insertions(+), 55 deletions(-) diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py index 2760ce7bef..ed0b9510a1 100644 --- a/haystack/components/generators/hugging_face_local.py +++ b/haystack/components/generators/hugging_face_local.py @@ -13,9 +13,8 @@ SUPPORTED_TASKS = ["text-generation", "text2text-generation"] with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import: - from huggingface_hub import model_info from transformers import StoppingCriteriaList, pipeline - from haystack.utils.hf import StopWordsCriteria # pylint: disable=ungrouped-imports + from haystack.utils.hf import StopWordsCriteria, resolve_hf_pipeline_kwargs # pylint: disable=ungrouped-imports @component @@ -84,37 +83,21 @@ def __init__( """ transformers_import.check() - huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {} + self.token = token generation_kwargs = generation_kwargs or {} - self.token = token - token = token.resolve_value() if token else None - - # check if the huggingface_pipeline_kwargs contain the essential parameters - # otherwise, populate them with values from other init parameters - huggingface_pipeline_kwargs.setdefault("model", model) - huggingface_pipeline_kwargs.setdefault("token", token) - - device = ComponentDevice.resolve_device(device) - device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False) - - # task identification and validation - if task is None: - if "task" in huggingface_pipeline_kwargs: - task = huggingface_pipeline_kwargs["task"] - elif isinstance(huggingface_pipeline_kwargs["model"], str): - task = model_info( - huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"] - ).pipeline_tag - - if task not in SUPPORTED_TASKS: - raise ValueError( - f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(SUPPORTED_TASKS)}." - ) - huggingface_pipeline_kwargs["task"] = task + huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs( + huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {}, + model=model, + task=task, + supported_tasks=SUPPORTED_TASKS, + device=device, + token=token, + ) # if not specified, set return_full_text to False for text-generation # only generated text is returned (excluding prompt) + task = huggingface_pipeline_kwargs["task"] if task == "text-generation": generation_kwargs.setdefault("return_full_text", False) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index cdd751b24a..3b8d92657b 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any, List, Dict, Optional, Union -from haystack import component +from haystack import component, default_to_dict, default_from_dict from haystack.lazy_imports import LazyImport from haystack.utils import ComponentDevice from haystack.utils import Secret, deserialize_secrets_inplace @@ -12,8 +12,8 @@ SUPPORTED_TASKS = ["zero-shot-classification"] with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: - from huggingface_hub import model_info from transformers import pipeline + from haystack.utils.hf import resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs, deserialize_hf_model_kwargs @component @@ -39,7 +39,7 @@ class TransformersTextRouter: def __init__( self, labels: List[str], - model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", + model: Union[str, Path] = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", device: Optional[ComponentDevice] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), pipeline_kwargs: Optional[Dict[str, Any]] = None, @@ -50,34 +50,19 @@ def __init__( Hugging Face pipeline for zero shot text classification. """ torch_and_transformers_import.check() + self.token = token self.labels = labels component.set_output_types(self, **{label: str for label in labels}) - token = token.resolve_value() if token else None - - # check if the pipeline_kwargs contain the essential parameters - # otherwise, populate them with values from other init parameters - pipeline_kwargs.setdefault("model", model) - pipeline_kwargs.setdefault("token", token) - - device = ComponentDevice.resolve_device(device) - device.update_hf_kwargs(pipeline_kwargs, overwrite=False) - - # task identification and validation - task = "zero-shot-classification" - if task is None: - if "task" in pipeline_kwargs: - task = pipeline_kwargs["task"] - elif isinstance(pipeline_kwargs["model"], str): - task = model_info(pipeline_kwargs["model"], token=pipeline_kwargs["token"]).pipeline_tag - - if task not in SUPPORTED_TASKS: - raise ValueError( - f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(SUPPORTED_TASKS)}." - ) - pipeline_kwargs["task"] = task - + pipeline_kwargs = resolve_hf_pipeline_kwargs( + huggingface_pipeline_kwargs=pipeline_kwargs, + model=model, + task="zero-shot-classification", + supported_tasks=SUPPORTED_TASKS, + device=device, + token=token, + ) self.pipeline_kwargs = pipeline_kwargs self.pipeline = None @@ -93,6 +78,29 @@ def warm_up(self): if self.pipeline is None: self.pipeline = pipeline(**self.pipeline_kwargs) + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + serialization_dict = default_to_dict( + self, pipeline_kwargs=self.pipeline_kwargs, token=self.token.to_dict() if self.token else None + ) + + pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"] + pipeline_kwargs.pop("token", None) + + serialize_hf_model_kwargs(pipeline_kwargs) + return serialization_dict + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TransformersTextRouter": + """ + Deserialize this component from a dictionary. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_hf_model_kwargs(data["init_parameters"]["pipeline_kwargs"]) + return default_from_dict(cls, data) + def run(self, text: str) -> Dict[str, str]: """ Run the TransformersTextRouter. This method routes the text to one of the different edges based on which label diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index bc8a5e721d..01c056c6ac 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -15,6 +15,7 @@ import torch with LazyImport(message="Run 'pip install transformers'") as transformers_import: + from huggingface_hub import model_info from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub import InferenceClient, HfApi @@ -94,6 +95,50 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio return model_kwargs +def resolve_hf_pipeline_kwargs( + huggingface_pipeline_kwargs: Dict[str, Any], + model: str, + task: Optional[str], + supported_tasks: List[str], + device: Optional[ComponentDevice], + token: Optional[Secret], +) -> Dict[str, Any]: + """ + Resolve the HuggingFace pipeline keyword arguments based on explicit user inputs. + + :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize a + Hugging Face pipeline. + :param model: The name or path of a Hugging Face model for on the HuggingFace Hub. + :param task: The task for the Hugging Face pipeline. + :param supported_tasks: The list of supported tasks to check the task of the model against. If the task of the model + is not present within this list then a ValueError is thrown. + :param device: The device on which the model is loaded. If `None`, the default device is automatically + selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter. + :param token: The token to use as HTTP bearer authorization for remote files. + If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. + """ + transformers_import.check() + + token = token.resolve_value() if token else None + # check if the huggingface_pipeline_kwargs contain the essential parameters + # otherwise, populate them with values from other init parameters + huggingface_pipeline_kwargs.setdefault("model", model) + huggingface_pipeline_kwargs.setdefault("token", token) + + device = ComponentDevice.resolve_device(device) + device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False) + + # task identification and validation + task = task or huggingface_pipeline_kwargs.get("task") + if task is None and isinstance(huggingface_pipeline_kwargs["model"], str): + task = model_info(huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]).pipeline_tag + + if task not in supported_tasks: + raise ValueError(f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(supported_tasks)}.") + huggingface_pipeline_kwargs["task"] = task + return huggingface_pipeline_kwargs + + def list_inference_deployed_models(headers: Optional[Dict] = None) -> List[str]: """ List all currently deployed models on HF TGI free tier From 44237ebb4e4d6be977280acdbc612e608d52c990 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 16 Feb 2024 14:53:49 +0100 Subject: [PATCH 03/20] Fix pylint --- haystack/components/routers/zero_shot_text_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 3b8d92657b..92d53c34e8 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -39,7 +39,7 @@ class TransformersTextRouter: def __init__( self, labels: List[str], - model: Union[str, Path] = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", + model: str = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", device: Optional[ComponentDevice] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), pipeline_kwargs: Optional[Dict[str, Any]] = None, From 196182c7eec9d3f976b96a454538a0911fa6a601 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 16 Feb 2024 14:54:22 +0100 Subject: [PATCH 04/20] Remove unneeded imports --- haystack/components/routers/zero_shot_text_router.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 92d53c34e8..fd30db6681 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -1,6 +1,5 @@ import logging -from pathlib import Path -from typing import Any, List, Dict, Optional, Union +from typing import Any, List, Dict, Optional from haystack import component, default_to_dict, default_from_dict from haystack.lazy_imports import LazyImport From d008fdc8be0a4fa1b2a8d69f2ca37a4f6d891b4b Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 23 Feb 2024 16:03:52 +0100 Subject: [PATCH 05/20] Update documentation example --- .../routers/zero_shot_text_router.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index fd30db6681..aa973d42ac 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) -SUPPORTED_TASKS = ["zero-shot-classification"] with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: from transformers import pipeline @@ -16,7 +15,7 @@ @component -class TransformersTextRouter: +class ZeroShotTextRouter: """ Routes a text input onto different output connections depending on which label it has been categorized into. This is useful for routing queries to different models in a pipeline depending on their categorization. @@ -28,10 +27,29 @@ class TransformersTextRouter: ```python document_store = InMemoryDocumentStore() p = Pipeline() - p.add_component(instance=TransformersTextRouter(), name="text_router") - p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever") - p.connect("text_router.en", "retriever.query") + p.add_component(instance=ZeroShotTextRouter(labels=["passage", "query"]), name="text_router") + p.add_component( + instance=SentenceTransformersTextEmbedder( + document_store=document_store, model="intfloat/e5-base-v2", prefix="passage: " + ), + name="passage_embedder" + ) + p.add_component( + instance=SentenceTransformersTextEmbedder( + document_store=document_store, model="intfloat/e5-base-v2", prefix="query: " + ), + name="query_embedder" + ) + p.connect("text_router.passage", "passage_embedder.text") + p.connect("text_router.query", "query_embedder.text") + # Query Example p.run({"text_router": {"text": "What's your query?"}}) + # Passage Example + p.run({ + "text_router":{ + "text": "Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app." + } + }) ``` """ @@ -58,7 +76,7 @@ def __init__( huggingface_pipeline_kwargs=pipeline_kwargs, model=model, task="zero-shot-classification", - supported_tasks=SUPPORTED_TASKS, + supported_tasks=["zero-shot-classification"], device=device, token=token, ) @@ -92,7 +110,7 @@ def to_dict(self) -> Dict[str, Any]: return serialization_dict @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TransformersTextRouter": + def from_dict(cls, data: Dict[str, Any]) -> "ZeroShotTextRouter": """ Deserialize this component from a dictionary. """ From 093b5331f86ec66364a41bbc242c1b4eec438f61 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 23 Feb 2024 16:07:17 +0100 Subject: [PATCH 06/20] Update error message strings --- haystack/components/routers/zero_shot_text_router.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index aa973d42ac..0aca7a7c33 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -120,7 +120,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ZeroShotTextRouter": def run(self, text: str) -> Dict[str, str]: """ - Run the TransformersTextRouter. This method routes the text to one of the different edges based on which label + Run the ZeroShotTextRouter. This method routes the text to one of the different edges based on which label it has been categorized into. :param text: A str to route to one of the different edges. @@ -131,7 +131,7 @@ def run(self, text: str) -> Dict[str, str]: ) if not isinstance(text, str): - raise TypeError("TransformersTextRouter expects a str as input.") + raise TypeError("ZeroShotTextRouter expects a str as input.") prediction = self.pipeline(sequences=[text], candidate_labels=self.labels, multi_label=self.multi_label) label = prediction[0]["labels"][0] From a78231fbe2c798b6ec6065a5f7b8a33c3a584394 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 23 Feb 2024 16:19:35 +0100 Subject: [PATCH 07/20] Starting to add unit tests --- .../routers/zero_shot_text_router.py | 2 +- .../routers/test_zero_shot_text_router.py | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 test/components/routers/test_zero_shot_text_router.py diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 0aca7a7c33..7e43607d0d 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -73,7 +73,7 @@ def __init__( component.set_output_types(self, **{label: str for label in labels}) pipeline_kwargs = resolve_hf_pipeline_kwargs( - huggingface_pipeline_kwargs=pipeline_kwargs, + huggingface_pipeline_kwargs=pipeline_kwargs or {}, model=model, task="zero-shot-classification", supported_tasks=["zero-shot-classification"], diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py new file mode 100644 index 0000000000..60051f6771 --- /dev/null +++ b/test/components/routers/test_zero_shot_text_router.py @@ -0,0 +1,35 @@ +import pytest +from unittest.mock import patch +from haystack.components.routers.zero_shot_text_router import ZeroShotTextRouter + + +@pytest.fixture +def text_router(): + return ZeroShotTextRouter(labels=["query", "passage"]) + + +class TestFileTypeRouter: + # def test_to_dict(self): + # router = ZeroShotTextRouter(labels=["query", "passage"]) + # router_dict = router.to_dict() + # pass + + # def test_from_dict(self): + # pass + + def test_run_error(self): + router = ZeroShotTextRouter(labels=["query", "passage"]) + with pytest.raises(RuntimeError): + router.run(text="test") + + @patch("haystack.components.routers.zero_shot_text_router.pipeline") + def test_run_not_str_error(self, hf_pipeline_mock): + hf_pipeline_mock.return_value = " " + router = ZeroShotTextRouter(labels=["query", "passage"]) + router.warm_up() + with pytest.raises(TypeError): + router.run(text=["wrong_input"]) + + # @pytest.mark.integration + # def test_run(self): + # pass From 5695c9f051a2ddf6817ba69106fc02ae94f46927 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 23 Feb 2024 16:23:01 +0100 Subject: [PATCH 08/20] Release notes --- .../notes/zero-shot-text-router-f5090589e652197c.yaml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 releasenotes/notes/zero-shot-text-router-f5090589e652197c.yaml diff --git a/releasenotes/notes/zero-shot-text-router-f5090589e652197c.yaml b/releasenotes/notes/zero-shot-text-router-f5090589e652197c.yaml new file mode 100644 index 0000000000..20c158e01e --- /dev/null +++ b/releasenotes/notes/zero-shot-text-router-f5090589e652197c.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add a Zero Shot Text Router that uses an NLI model from HF to classify texts based on a set of provided labels and routes them based on the label they were classified with. From 8c0a786e54f1862eaedcb3dfd1f903cceb1c4e3e Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Wed, 28 Feb 2024 14:36:18 +0100 Subject: [PATCH 09/20] Fix pylint --- haystack/components/routers/zero_shot_text_router.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 7e43607d0d..99bc87f005 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -11,7 +11,11 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: from transformers import pipeline - from haystack.utils.hf import resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs, deserialize_hf_model_kwargs + from haystack.utils.hf import ( # pylint: disable=ungrouped-imports + resolve_hf_pipeline_kwargs, + serialize_hf_model_kwargs, + deserialize_hf_model_kwargs, + ) @component From 0fb6db4216a2c0a18c6cd5049ac308b463d0a4aa Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Wed, 28 Feb 2024 14:54:11 +0100 Subject: [PATCH 10/20] Add tests for to dict and from dict --- .../routers/zero_shot_text_router.py | 5 +- .../routers/test_zero_shot_text_router.py | 46 ++++++++++++++++--- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 99bc87f005..ebe90dc272 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -104,7 +104,10 @@ def to_dict(self) -> Dict[str, Any]: Serialize this component to a dictionary. """ serialization_dict = default_to_dict( - self, pipeline_kwargs=self.pipeline_kwargs, token=self.token.to_dict() if self.token else None + self, + labels=self.labels, + pipeline_kwargs=self.pipeline_kwargs, + token=self.token.to_dict() if self.token else None, ) pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"] diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index 60051f6771..903a7a01de 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import patch from haystack.components.routers.zero_shot_text_router import ZeroShotTextRouter +from haystack.utils import ComponentDevice, Secret @pytest.fixture @@ -9,13 +10,46 @@ def text_router(): class TestFileTypeRouter: - # def test_to_dict(self): - # router = ZeroShotTextRouter(labels=["query", "passage"]) - # router_dict = router.to_dict() - # pass + def test_to_dict(self): + router = ZeroShotTextRouter(labels=["query", "passage"]) + router_dict = router.to_dict() + assert router_dict == { + "type": "haystack.components.routers.zero_shot_text_router.ZeroShotTextRouter", + "init_parameters": { + "labels": ["query", "passage"], + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "pipeline_kwargs": { + "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "zero-shot-classification", + }, + }, + } - # def test_from_dict(self): - # pass + def test_from_dict(self): + data = { + "type": "haystack.components.routers.zero_shot_text_router.ZeroShotTextRouter", + "init_parameters": { + "labels": ["query", "passage"], + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "pipeline_kwargs": { + "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "zero-shot-classification", + }, + }, + } + + component = ZeroShotTextRouter.from_dict(data) + assert component.labels == ["query", "passage"] + assert component.pipeline is None + assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}) + assert component.pipeline_kwargs == { + "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", + "device": ComponentDevice.resolve_device(None).to_hf(), + "task": "zero-shot-classification", + "token": None, + } def test_run_error(self): router = ZeroShotTextRouter(labels=["query", "passage"]) From fa7d160f50a7cd8c9f143a2409629440be1d9a8c Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Wed, 28 Feb 2024 15:01:34 +0100 Subject: [PATCH 11/20] Update patches in tests to be correct with respect to changes --- .../generators/test_hugging_face_local_generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py index 714decc56a..efd77934ad 100644 --- a/test/components/generators/test_hugging_face_local_generator.py +++ b/test/components/generators/test_hugging_face_local_generator.py @@ -11,7 +11,7 @@ class TestHuggingFaceLocalGenerator: - @patch("haystack.components.generators.hugging_face_local.model_info") + @patch("haystack.utils.hf.model_info") def test_init_default(self, model_info_mock, monkeypatch): monkeypatch.delenv("HF_API_TOKEN", raising=False) model_info_mock.return_value.pipeline_tag = "text2text-generation" @@ -73,7 +73,7 @@ def test_init_task_in_huggingface_pipeline_kwargs(self): "device": ComponentDevice.resolve_device(None).to_hf(), } - @patch("haystack.components.generators.hugging_face_local.model_info") + @patch("haystack.utils.hf.model_info") def test_init_task_inferred_from_model_name(self, model_info_mock): model_info_mock.return_value.pipeline_tag = "text2text-generation" generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", token=None) @@ -137,7 +137,7 @@ def test_init_fails_with_both_stopwords_and_stoppingcriteria(self): generation_kwargs={"stopping_criteria": "fake-stopping-criteria"}, ) - @patch("haystack.components.generators.hugging_face_local.model_info") + @patch("haystack.utils.hf.model_info") def test_to_dict_default(self, model_info_mock): model_info_mock.return_value.pipeline_tag = "text2text-generation" From 8dcff28db65d827558afdd8cc843b8fe29d2da5c Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 29 Feb 2024 09:00:24 +0100 Subject: [PATCH 12/20] Doc strings and fixes --- .../routers/zero_shot_text_router.py | 19 +++++++++++++++++-- .../routers/test_zero_shot_text_router.py | 15 +++++++-------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index ebe90dc272..94d6cf6ec2 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -60,13 +60,25 @@ class ZeroShotTextRouter: def __init__( self, labels: List[str], + multi_label: bool = False, model: str = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", device: Optional[ComponentDevice] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), pipeline_kwargs: Optional[Dict[str, Any]] = None, ): """ - :param labels: + :param labels: The set of possible class labels to classify each sequence into. Can be a single label, + a string of comma-separated labels, or a list of labels. + :param multi_label: Whether or not multiple candidate labels can be true. + If False, the scores are normalized such that the sum of the label likelihoods for each sequence is 1. + If True, the labels are considered independent and probabilities are normalized for each candidate by + doing a softmax of the entailment score vs. the contradiction score. + :param model: The name or path of a Hugging Face model for zero-shot text classification. + :param device: The device on which the model is loaded. If `None`, the default device is automatically + selected. If a device/device map is specified in `pipeline_kwargs`, it overrides this parameter. + :param token: The API token used to download private models from Hugging Face. + If this parameter is set to `True`, the token generated when running + `transformers-cli login` (stored in ~/.huggingface) is used. :param pipeline_kwargs: Dictionary containing keyword arguments used to initialize the Hugging Face pipeline for zero shot text classification. """ @@ -74,6 +86,7 @@ def __init__( self.token = token self.labels = labels + self.multi_label = multi_label component.set_output_types(self, **{label: str for label in labels}) pipeline_kwargs = resolve_hf_pipeline_kwargs( @@ -141,5 +154,7 @@ def run(self, text: str) -> Dict[str, str]: raise TypeError("ZeroShotTextRouter expects a str as input.") prediction = self.pipeline(sequences=[text], candidate_labels=self.labels, multi_label=self.multi_label) - label = prediction[0]["labels"][0] + predicted_scores = prediction[0]["scores"] + max_score_index = max(range(len(predicted_scores)), key=predicted_scores.__getitem__) + label = prediction[0]["labels"][max_score_index] return {label: text} diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index 903a7a01de..43fd0ea44e 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -4,11 +4,6 @@ from haystack.utils import ComponentDevice, Secret -@pytest.fixture -def text_router(): - return ZeroShotTextRouter(labels=["query", "passage"]) - - class TestFileTypeRouter: def test_to_dict(self): router = ZeroShotTextRouter(labels=["query", "passage"]) @@ -64,6 +59,10 @@ def test_run_not_str_error(self, hf_pipeline_mock): with pytest.raises(TypeError): router.run(text=["wrong_input"]) - # @pytest.mark.integration - # def test_run(self): - # pass + @pytest.mark.integration + def test_run(self): + router = ZeroShotTextRouter(labels=["query", "passage"]) + router.warm_up() + out = router.run("What is the color of the sky?") + assert router.pipeline is not None + assert out == {"query": "What is the color of the sky?"} From cd618ff04368cb0ff7d12410ccc7e92dd2ff2d89 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Wed, 13 Mar 2024 12:02:05 +0100 Subject: [PATCH 13/20] Adding more tests --- .../routers/test_zero_shot_text_router.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index 43fd0ea44e..dc91f6cd19 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -1,10 +1,12 @@ -import pytest from unittest.mock import patch + +import pytest + from haystack.components.routers.zero_shot_text_router import ZeroShotTextRouter from haystack.utils import ComponentDevice, Secret -class TestFileTypeRouter: +class TestZeroShotTextRouter: def test_to_dict(self): router = ZeroShotTextRouter(labels=["query", "passage"]) router_dict = router.to_dict() @@ -46,6 +48,12 @@ def test_from_dict(self): "token": None, } + @patch("haystack.components.routers.zero_shot_text_router.pipeline") + def test_warm_up(self, hf_pipeline_mock): + router = ZeroShotTextRouter(labels=["query", "passage"]) + router.warm_up() + assert router.pipeline is not None + def test_run_error(self): router = ZeroShotTextRouter(labels=["query", "passage"]) with pytest.raises(RuntimeError): @@ -59,6 +67,17 @@ def test_run_not_str_error(self, hf_pipeline_mock): with pytest.raises(TypeError): router.run(text=["wrong_input"]) + @patch("haystack.components.routers.zero_shot_text_router.pipeline") + def test_run_unit(self, hf_pipeline_mock): + hf_pipeline_mock.return_value = [ + {"sequence": "What is the color of the sky?", "labels": ["query", "passage"], "scores": [0.9, 0.1]} + ] + router = ZeroShotTextRouter(labels=["query", "passage"]) + router.pipeline = hf_pipeline_mock + out = router.run("What is the color of the sky?") + assert router.pipeline is not None + assert out == {"query": "What is the color of the sky?"} + @pytest.mark.integration def test_run(self): router = ZeroShotTextRouter(labels=["query", "passage"]) From 1f2505c372a27a8469361cc852b9818e6eca80a6 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Wed, 13 Mar 2024 12:05:21 +0100 Subject: [PATCH 14/20] Change name --- .../routers/zero_shot_text_router.py | 20 ++++++++--------- .../routers/test_zero_shot_text_router.py | 22 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 94d6cf6ec2..396d6fa0e2 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -1,25 +1,25 @@ import logging -from typing import Any, List, Dict, Optional +from typing import Any, Dict, List, Optional -from haystack import component, default_to_dict, default_from_dict +from haystack import component, default_from_dict, default_to_dict from haystack.lazy_imports import LazyImport -from haystack.utils import ComponentDevice -from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace logger = logging.getLogger(__name__) with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: from transformers import pipeline + from haystack.utils.hf import ( # pylint: disable=ungrouped-imports + deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs, - deserialize_hf_model_kwargs, ) @component -class ZeroShotTextRouter: +class TransformersZeroShotTextRouter: """ Routes a text input onto different output connections depending on which label it has been categorized into. This is useful for routing queries to different models in a pipeline depending on their categorization. @@ -31,7 +31,7 @@ class ZeroShotTextRouter: ```python document_store = InMemoryDocumentStore() p = Pipeline() - p.add_component(instance=ZeroShotTextRouter(labels=["passage", "query"]), name="text_router") + p.add_component(instance=TransformersZeroShotTextRouter(labels=["passage", "query"]), name="text_router") p.add_component( instance=SentenceTransformersTextEmbedder( document_store=document_store, model="intfloat/e5-base-v2", prefix="passage: " @@ -130,7 +130,7 @@ def to_dict(self) -> Dict[str, Any]: return serialization_dict @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ZeroShotTextRouter": + def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotTextRouter": """ Deserialize this component from a dictionary. """ @@ -140,7 +140,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ZeroShotTextRouter": def run(self, text: str) -> Dict[str, str]: """ - Run the ZeroShotTextRouter. This method routes the text to one of the different edges based on which label + Run the TransformersZeroShotTextRouter. This method routes the text to one of the different edges based on which label it has been categorized into. :param text: A str to route to one of the different edges. @@ -151,7 +151,7 @@ def run(self, text: str) -> Dict[str, str]: ) if not isinstance(text, str): - raise TypeError("ZeroShotTextRouter expects a str as input.") + raise TypeError("TransformersZeroShotTextRouter expects a str as input.") prediction = self.pipeline(sequences=[text], candidate_labels=self.labels, multi_label=self.multi_label) predicted_scores = prediction[0]["scores"] diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index dc91f6cd19..f873b427c1 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -2,16 +2,16 @@ import pytest -from haystack.components.routers.zero_shot_text_router import ZeroShotTextRouter +from haystack.components.routers.zero_shot_text_router import TransformersZeroShotTextRouter from haystack.utils import ComponentDevice, Secret -class TestZeroShotTextRouter: +class TestTransformersZeroShotTextRouter: def test_to_dict(self): - router = ZeroShotTextRouter(labels=["query", "passage"]) + router = TransformersZeroShotTextRouter(labels=["query", "passage"]) router_dict = router.to_dict() assert router_dict == { - "type": "haystack.components.routers.zero_shot_text_router.ZeroShotTextRouter", + "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter", "init_parameters": { "labels": ["query", "passage"], "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, @@ -25,7 +25,7 @@ def test_to_dict(self): def test_from_dict(self): data = { - "type": "haystack.components.routers.zero_shot_text_router.ZeroShotTextRouter", + "type": "haystack.components.routers.zero_shot_text_router.TransformersZeroShotTextRouter", "init_parameters": { "labels": ["query", "passage"], "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, @@ -37,7 +37,7 @@ def test_from_dict(self): }, } - component = ZeroShotTextRouter.from_dict(data) + component = TransformersZeroShotTextRouter.from_dict(data) assert component.labels == ["query", "passage"] assert component.pipeline is None assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}) @@ -50,19 +50,19 @@ def test_from_dict(self): @patch("haystack.components.routers.zero_shot_text_router.pipeline") def test_warm_up(self, hf_pipeline_mock): - router = ZeroShotTextRouter(labels=["query", "passage"]) + router = TransformersZeroShotTextRouter(labels=["query", "passage"]) router.warm_up() assert router.pipeline is not None def test_run_error(self): - router = ZeroShotTextRouter(labels=["query", "passage"]) + router = TransformersZeroShotTextRouter(labels=["query", "passage"]) with pytest.raises(RuntimeError): router.run(text="test") @patch("haystack.components.routers.zero_shot_text_router.pipeline") def test_run_not_str_error(self, hf_pipeline_mock): hf_pipeline_mock.return_value = " " - router = ZeroShotTextRouter(labels=["query", "passage"]) + router = TransformersZeroShotTextRouter(labels=["query", "passage"]) router.warm_up() with pytest.raises(TypeError): router.run(text=["wrong_input"]) @@ -72,7 +72,7 @@ def test_run_unit(self, hf_pipeline_mock): hf_pipeline_mock.return_value = [ {"sequence": "What is the color of the sky?", "labels": ["query", "passage"], "scores": [0.9, 0.1]} ] - router = ZeroShotTextRouter(labels=["query", "passage"]) + router = TransformersZeroShotTextRouter(labels=["query", "passage"]) router.pipeline = hf_pipeline_mock out = router.run("What is the color of the sky?") assert router.pipeline is not None @@ -80,7 +80,7 @@ def test_run_unit(self, hf_pipeline_mock): @pytest.mark.integration def test_run(self): - router = ZeroShotTextRouter(labels=["query", "passage"]) + router = TransformersZeroShotTextRouter(labels=["query", "passage"]) router.warm_up() out = router.run("What is the color of the sky?") assert router.pipeline is not None From 160e257056ed4e9d594ca2ca03d572a33a5a66a0 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Wed, 13 Mar 2024 12:10:44 +0100 Subject: [PATCH 15/20] Adding to init --- haystack/components/routers/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/haystack/components/routers/__init__.py b/haystack/components/routers/__init__.py index 68aaeb3f53..39abe919d0 100644 --- a/haystack/components/routers/__init__.py +++ b/haystack/components/routers/__init__.py @@ -2,5 +2,12 @@ from haystack.components.routers.file_type_router import FileTypeRouter from haystack.components.routers.metadata_router import MetadataRouter from haystack.components.routers.text_language_router import TextLanguageRouter +from haystack.components.routers.zero_shot_text_router import TransformersZeroShotTextRouter -__all__ = ["FileTypeRouter", "MetadataRouter", "TextLanguageRouter", "ConditionalRouter"] +__all__ = [ + "FileTypeRouter", + "MetadataRouter", + "TextLanguageRouter", + "ConditionalRouter", + "TransformersZeroShotTextRouter", +] From 3c654a6dda024130351b4ec143717f166285fc05 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 14 Mar 2024 09:59:00 +0100 Subject: [PATCH 16/20] Use Haystack logger --- haystack/components/routers/zero_shot_text_router.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 396d6fa0e2..92d7d71631 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -1,7 +1,6 @@ -import logging from typing import Any, Dict, List, Optional -from haystack import component, default_from_dict, default_to_dict +from haystack import component, default_from_dict, default_to_dict, logging from haystack.lazy_imports import LazyImport from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace From bb3c6a0f10e774168f62405ceff277b13c0a1e2d Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 14 Mar 2024 15:09:46 +0100 Subject: [PATCH 17/20] Beef up docstrings --- .../routers/zero_shot_text_router.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 92d7d71631..515cb4af2f 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -108,12 +108,18 @@ def _get_telemetry_data(self) -> Dict[str, Any]: return {"model": f"[object of type {type(self.pipeline_kwargs['model'])}]"} def warm_up(self): + """ + Initializes the component. + """ if self.pipeline is None: self.pipeline = pipeline(**self.pipeline_kwargs) def to_dict(self) -> Dict[str, Any]: """ - Serialize this component to a dictionary. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ serialization_dict = default_to_dict( self, @@ -131,18 +137,31 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotTextRouter": """ - Deserialize this component from a dictionary. + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. """ deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) deserialize_hf_model_kwargs(data["init_parameters"]["pipeline_kwargs"]) return default_from_dict(cls, data) - def run(self, text: str) -> Dict[str, str]: + @component.output_types(documents=Dict[str, str]) + def run(self, text: str): """ Run the TransformersZeroShotTextRouter. This method routes the text to one of the different edges based on which label it has been categorized into. :param text: A str to route to one of the different edges. + :returns: + A dictionary with the label as key and the text as value. + + :raises TypeError: + If the input is not a str. + :raises RuntimeError: + If the pipeline has not been loaded. """ if self.pipeline is None: raise RuntimeError( From 80f673407bee5d507186bbcb7dab2d476b31cbb3 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 14 Mar 2024 15:27:27 +0100 Subject: [PATCH 18/20] Make example runnable --- .../components/routers/zero_shot_text_router.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 515cb4af2f..a3f4eeaca1 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -28,19 +28,18 @@ class TransformersZeroShotTextRouter: queries to a BM25 retriever: ```python - document_store = InMemoryDocumentStore() + from haystack.core.pipeline import Pipeline + from haystack.components.routers import TransformersZeroShotTextRouter + from haystack.components.embedders import SentenceTransformersTextEmbedder + p = Pipeline() p.add_component(instance=TransformersZeroShotTextRouter(labels=["passage", "query"]), name="text_router") p.add_component( - instance=SentenceTransformersTextEmbedder( - document_store=document_store, model="intfloat/e5-base-v2", prefix="passage: " - ), + instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="passage: "), name="passage_embedder" ) p.add_component( - instance=SentenceTransformersTextEmbedder( - document_store=document_store, model="intfloat/e5-base-v2", prefix="query: " - ), + instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="query: "), name="query_embedder" ) p.connect("text_router.passage", "passage_embedder.text") From a4612fd212b7745849ba603b69cfc20e0ae5af17 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 15 Mar 2024 10:19:14 +0100 Subject: [PATCH 19/20] Rename to huggingface_pipeline_kwargs --- .../routers/zero_shot_text_router.py | 50 ++++++++++++------- .../routers/test_zero_shot_text_router.py | 6 +-- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index a3f4eeaca1..7c9928b34d 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -24,13 +24,19 @@ class TransformersZeroShotTextRouter: This is useful for routing queries to different models in a pipeline depending on their categorization. The set of labels to be used for categorization can be specified. - Example usage in a retrieval pipeline that passes question-like queries to an embedding retriever and keyword-like - queries to a BM25 retriever: + Example usage in a retrieval pipeline that passes question-like queries to an embedding retriever optimized for + query-passage retrieval and passage-like queries to an embedding retriever optimized for passage-passage retrieval. ```python + from haystack import Document + from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.core.pipeline import Pipeline from haystack.components.routers import TransformersZeroShotTextRouter from haystack.components.embedders import SentenceTransformersTextEmbedder + from haystack.components.retrievers import InMemoryEmbeddingRetriever + + document_store = InMemoryDocumentStore() + document_store.write_documents([Document(text="The capital of Germany is Berlin.")]) p = Pipeline() p.add_component(instance=TransformersZeroShotTextRouter(labels=["passage", "query"]), name="text_router") @@ -42,14 +48,20 @@ class TransformersZeroShotTextRouter: instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="query: "), name="query_embedder" ) + p.add_component( + instance=InMemoryEmbeddingRetriever(document_store=document_store, embedding_field="passage_embedding"), + ) + p.connect("text_router.passage", "passage_embedder.text") p.connect("text_router.query", "query_embedder.text") + # Query Example - p.run({"text_router": {"text": "What's your query?"}}) + p.run({"text_router": {"text": "What is the capital of Germany?"}}) + # Passage Example p.run({ "text_router":{ - "text": "Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app." + "text": "The capital of France is Paris.", } }) ``` @@ -62,7 +74,7 @@ def __init__( model: str = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", device: Optional[ComponentDevice] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), - pipeline_kwargs: Optional[Dict[str, Any]] = None, + huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, ): """ :param labels: The set of possible class labels to classify each sequence into. Can be a single label, @@ -73,11 +85,11 @@ def __init__( doing a softmax of the entailment score vs. the contradiction score. :param model: The name or path of a Hugging Face model for zero-shot text classification. :param device: The device on which the model is loaded. If `None`, the default device is automatically - selected. If a device/device map is specified in `pipeline_kwargs`, it overrides this parameter. + selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter. :param token: The API token used to download private models from Hugging Face. If this parameter is set to `True`, the token generated when running `transformers-cli login` (stored in ~/.huggingface) is used. - :param pipeline_kwargs: Dictionary containing keyword arguments used to initialize the + :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the Hugging Face pipeline for zero shot text classification. """ torch_and_transformers_import.check() @@ -87,31 +99,31 @@ def __init__( self.multi_label = multi_label component.set_output_types(self, **{label: str for label in labels}) - pipeline_kwargs = resolve_hf_pipeline_kwargs( - huggingface_pipeline_kwargs=pipeline_kwargs or {}, + huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs( + huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {}, model=model, task="zero-shot-classification", supported_tasks=["zero-shot-classification"], device=device, token=token, ) - self.pipeline_kwargs = pipeline_kwargs + self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs self.pipeline = None def _get_telemetry_data(self) -> Dict[str, Any]: """ Data that is sent to Posthog for usage analytics. """ - if isinstance(self.pipeline_kwargs["model"], str): - return {"model": self.pipeline_kwargs["model"]} - return {"model": f"[object of type {type(self.pipeline_kwargs['model'])}]"} + if isinstance(self.huggingface_pipeline_kwargs["model"], str): + return {"model": self.huggingface_pipeline_kwargs["model"]} + return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"} def warm_up(self): """ Initializes the component. """ if self.pipeline is None: - self.pipeline = pipeline(**self.pipeline_kwargs) + self.pipeline = pipeline(**self.huggingface_pipeline_kwargs) def to_dict(self) -> Dict[str, Any]: """ @@ -123,14 +135,14 @@ def to_dict(self) -> Dict[str, Any]: serialization_dict = default_to_dict( self, labels=self.labels, - pipeline_kwargs=self.pipeline_kwargs, + huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs, token=self.token.to_dict() if self.token else None, ) - pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"] - pipeline_kwargs.pop("token", None) + huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"] + huggingface_pipeline_kwargs.pop("token", None) - serialize_hf_model_kwargs(pipeline_kwargs) + serialize_hf_model_kwargs(huggingface_pipeline_kwargs) return serialization_dict @classmethod @@ -144,7 +156,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotTextRouter": Deserialized component. """ deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) - deserialize_hf_model_kwargs(data["init_parameters"]["pipeline_kwargs"]) + deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"]) return default_from_dict(cls, data) @component.output_types(documents=Dict[str, str]) diff --git a/test/components/routers/test_zero_shot_text_router.py b/test/components/routers/test_zero_shot_text_router.py index f873b427c1..6472134979 100644 --- a/test/components/routers/test_zero_shot_text_router.py +++ b/test/components/routers/test_zero_shot_text_router.py @@ -15,7 +15,7 @@ def test_to_dict(self): "init_parameters": { "labels": ["query", "passage"], "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, - "pipeline_kwargs": { + "huggingface_pipeline_kwargs": { "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", "device": ComponentDevice.resolve_device(None).to_hf(), "task": "zero-shot-classification", @@ -29,7 +29,7 @@ def test_from_dict(self): "init_parameters": { "labels": ["query", "passage"], "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, - "pipeline_kwargs": { + "huggingface_pipeline_kwargs": { "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", "device": ComponentDevice.resolve_device(None).to_hf(), "task": "zero-shot-classification", @@ -41,7 +41,7 @@ def test_from_dict(self): assert component.labels == ["query", "passage"] assert component.pipeline is None assert component.token == Secret.from_dict({"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}) - assert component.pipeline_kwargs == { + assert component.huggingface_pipeline_kwargs == { "model": "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", "device": ComponentDevice.resolve_device(None).to_hf(), "task": "zero-shot-classification", From ae67e15255d9133e209f3c1ddc45ad555d3ca3dd Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 15 Mar 2024 13:17:41 +0100 Subject: [PATCH 20/20] Fix example --- .../routers/zero_shot_text_router.py | 37 ++++++++++++++++--- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 7c9928b34d..0f08a8fb49 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -24,19 +24,35 @@ class TransformersZeroShotTextRouter: This is useful for routing queries to different models in a pipeline depending on their categorization. The set of labels to be used for categorization can be specified. - Example usage in a retrieval pipeline that passes question-like queries to an embedding retriever optimized for - query-passage retrieval and passage-like queries to an embedding retriever optimized for passage-passage retrieval. + Example usage in a retrieval pipeline that passes question-like queries to a text embedder optimized for + query-passage retrieval and passage-like queries to a text embedder optimized for passage-passage retrieval. ```python from haystack import Document from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.core.pipeline import Pipeline from haystack.components.routers import TransformersZeroShotTextRouter - from haystack.components.embedders import SentenceTransformersTextEmbedder + from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder from haystack.components.retrievers import InMemoryEmbeddingRetriever document_store = InMemoryDocumentStore() - document_store.write_documents([Document(text="The capital of Germany is Berlin.")]) + doc_embedder = SentenceTransformersDocumentEmbedder(model="intfloat/e5-base-v2") + doc_embedder.warm_up() + docs = [ + Document( + content="Germany, officially the Federal Republic of Germany, is a country in the western region of " + "Central Europe. The nation's capital and most populous city is Berlin and its main financial centre " + "is Frankfurt; the largest urban area is the Ruhr." + ), + Document( + content="France, officially the French Republic, is a country located primarily in Western Europe. " + "France is a unitary semi-presidential republic with its capital in Paris, the country's largest city " + "and main cultural and commercial centre; other major urban areas include Marseille, Lyon, Toulouse, " + "Lille, Bordeaux, Strasbourg, Nantes and Nice." + ) + ] + docs_with_embeddings = doc_embedder.run(docs) + document_store.write_documents(docs_with_embeddings["documents"]) p = Pipeline() p.add_component(instance=TransformersZeroShotTextRouter(labels=["passage", "query"]), name="text_router") @@ -49,11 +65,18 @@ class TransformersZeroShotTextRouter: name="query_embedder" ) p.add_component( - instance=InMemoryEmbeddingRetriever(document_store=document_store, embedding_field="passage_embedding"), + instance=InMemoryEmbeddingRetriever(document_store=document_store), + name="query_retriever" + ) + p.add_component( + instance=InMemoryEmbeddingRetriever(document_store=document_store), + name="passage_retriever" ) p.connect("text_router.passage", "passage_embedder.text") + p.connect("passage_embedder.embedding", "passage_retriever.query_embedding") p.connect("text_router.query", "query_embedder.text") + p.connect("query_embedder.embedding", "query_retriever.query_embedding") # Query Example p.run({"text_router": {"text": "What is the capital of Germany?"}}) @@ -61,7 +84,9 @@ class TransformersZeroShotTextRouter: # Passage Example p.run({ "text_router":{ - "text": "The capital of France is Paris.", + "text": "The United Kingdom of Great Britain and Northern Ireland, commonly known as the "\ + "United Kingdom (UK) or Britain, is a country in Northwestern Europe, off the north-western coast of "\ + "the continental mainland." } }) ```