feat: Add Zero Shot Transformers Text Router #7018

Merged · 25 commits · Mar 15, 2024
39 changes: 11 additions & 28 deletions haystack/components/generators/hugging_face_local.py
@@ -10,10 +10,9 @@
SUPPORTED_TASKS = ["text-generation", "text2text-generation"]

with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
from huggingface_hub import model_info
from transformers import StoppingCriteriaList, pipeline

from haystack.utils.hf import StopWordsCriteria # pylint: disable=ungrouped-imports
from haystack.utils.hf import StopWordsCriteria, resolve_hf_pipeline_kwargs # pylint: disable=ungrouped-imports


@component
@@ -84,37 +83,21 @@ def __init__(
"""
transformers_import.check()

huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
self.token = token
generation_kwargs = generation_kwargs or {}

self.token = token
token = token.resolve_value() if token else None

# check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model)
huggingface_pipeline_kwargs.setdefault("token", token)

device = ComponentDevice.resolve_device(device)
device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

# task identification and validation
if task is None:
if "task" in huggingface_pipeline_kwargs:
task = huggingface_pipeline_kwargs["task"]
elif isinstance(huggingface_pipeline_kwargs["model"], str):
task = model_info(
huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
).pipeline_tag

if task not in SUPPORTED_TASKS:
raise ValueError(
f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(SUPPORTED_TASKS)}."
)
huggingface_pipeline_kwargs["task"] = task
huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
model=model,
task=task,
supported_tasks=SUPPORTED_TASKS,
device=device,
token=token,
)

# if not specified, set return_full_text to False for text-generation
# only generated text is returned (excluding prompt)
task = huggingface_pipeline_kwargs["task"]
if task == "text-generation":
generation_kwargs.setdefault("return_full_text", False)

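The refactor above moves task resolution out of the generator and into the shared `resolve_hf_pipeline_kwargs` helper, without changing the generator's public behavior. A minimal usage sketch, assuming `transformers[torch]` is installed and the model is reachable on the Hub:

```python
from haystack.components.generators import HuggingFaceLocalGenerator

# The task is still inferred from the model's pipeline tag on the Hub when
# neither `task` nor huggingface_pipeline_kwargs["task"] is provided.
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base")
generator.warm_up()  # builds the transformers pipeline from the resolved kwargs
print(generator.run(prompt="Summarize zero-shot classification in one sentence."))
```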
9 changes: 8 additions & 1 deletion haystack/components/routers/__init__.py
@@ -2,5 +2,12 @@
from haystack.components.routers.file_type_router import FileTypeRouter
from haystack.components.routers.metadata_router import MetadataRouter
from haystack.components.routers.text_language_router import TextLanguageRouter
from haystack.components.routers.zero_shot_text_router import TransformersZeroShotTextRouter

__all__ = ["FileTypeRouter", "MetadataRouter", "TextLanguageRouter", "ConditionalRouter"]
__all__ = [
"FileTypeRouter",
"MetadataRouter",
"TextLanguageRouter",
"ConditionalRouter",
"TransformersZeroShotTextRouter",
]
177 changes: 177 additions & 0 deletions haystack/components/routers/zero_shot_text_router.py
@@ -0,0 +1,177 @@
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
from transformers import pipeline

from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
deserialize_hf_model_kwargs,
resolve_hf_pipeline_kwargs,
serialize_hf_model_kwargs,
)


@component
class TransformersZeroShotTextRouter:
"""
Routes a text input to one of several output connections, depending on the label it is classified with.
This is useful for routing queries to different models in a pipeline, depending on their category.
The set of candidate labels used for classification is configurable.

Example usage in a retrieval pipeline that passes question-like queries to an embedding retriever and keyword-like
queries to a BM25 retriever:

```python
from haystack.core.pipeline import Pipeline
from haystack.components.routers import TransformersZeroShotTextRouter
from haystack.components.embedders import SentenceTransformersTextEmbedder

p = Pipeline()
p.add_component(instance=TransformersZeroShotTextRouter(labels=["passage", "query"]), name="text_router")
p.add_component(
instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="passage: "),
name="passage_embedder"
)
p.add_component(
instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="query: "),
name="query_embedder"
)
p.connect("text_router.passage", "passage_embedder.text")
p.connect("text_router.query", "query_embedder.text")
# Query Example
p.run({"text_router": {"text": "What's your query?"}})
# Passage Example
p.run({
"text_router":{
"text": "Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app."
}
})
```
"""

def __init__(
self,
labels: List[str],
multi_label: bool = False,
model: str = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
pipeline_kwargs: Optional[Dict[str, Any]] = None,
):
"""
:param labels: The set of possible class labels to classify each sequence into. Can be a single label,
a string of comma-separated labels, or a list of labels.
:param multi_label: Whether or not multiple candidate labels can be true.
If False, the scores are normalized such that the sum of the label likelihoods for each sequence is 1.
If True, the labels are considered independent and probabilities are normalized for each candidate by
doing a softmax of the entailment score vs. the contradiction score.
:param model: The name or path of a Hugging Face model for zero-shot text classification.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `pipeline_kwargs`, it overrides this parameter.
:param token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, the token generated when running
`transformers-cli login` (stored in ~/.huggingface) is used.
:param pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
Hugging Face pipeline for zero-shot text classification.
"""
torch_and_transformers_import.check()

self.token = token
self.labels = labels
self.multi_label = multi_label
component.set_output_types(self, **{label: str for label in labels})

pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=pipeline_kwargs or {},
model=model,
task="zero-shot-classification",
supported_tasks=["zero-shot-classification"],
device=device,
token=token,
)
self.pipeline_kwargs = pipeline_kwargs
self.pipeline = None

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
if isinstance(self.pipeline_kwargs["model"], str):
return {"model": self.pipeline_kwargs["model"]}
return {"model": f"[object of type {type(self.pipeline_kwargs['model'])}]"}

def warm_up(self):
"""
Initializes the component.
"""
if self.pipeline is None:
self.pipeline = pipeline(**self.pipeline_kwargs)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
serialization_dict = default_to_dict(
self,
labels=self.labels,
pipeline_kwargs=self.pipeline_kwargs,
token=self.token.to_dict() if self.token else None,
)

pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
pipeline_kwargs.pop("token", None)

serialize_hf_model_kwargs(pipeline_kwargs)
return serialization_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotTextRouter":
"""
Deserializes the component from a dictionary.

:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_hf_model_kwargs(data["init_parameters"]["pipeline_kwargs"])
return default_from_dict(cls, data)

@component.output_types(documents=Dict[str, str])
def run(self, text: str):
"""
Run the TransformersZeroShotTextRouter and route the text to one of the output edges, based on the label
it is classified with.

:param text: A str to route to one of the different edges.
:returns:
A dictionary with the label as key and the text as value.

:raises TypeError:
If the input is not a str.
:raises RuntimeError:
If the pipeline has not been loaded.
"""
if self.pipeline is None:
raise RuntimeError(
"The zero-shot classification pipeline has not been loaded. Please call warm_up() before running."
)

if not isinstance(text, str):
raise TypeError("TransformersZeroShotTextRouter expects a str as input.")

prediction = self.pipeline(sequences=[text], candidate_labels=self.labels, multi_label=self.multi_label)
predicted_scores = prediction[0]["scores"]
max_score_index = max(range(len(predicted_scores)), key=predicted_scores.__getitem__)
label = prediction[0]["labels"][max_score_index]
return {label: text}
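For completeness, a standalone usage sketch outside a pipeline. This mirrors the component's own API as shown in the diff above; it assumes `transformers[torch,sentencepiece]` is installed and that the default zero-shot model can be downloaded on first use:

```python
from haystack.components.routers import TransformersZeroShotTextRouter

router = TransformersZeroShotTextRouter(labels=["query", "passage"])
router.warm_up()  # loads the zero-shot-classification pipeline
result = router.run(text="What is the capital of France?")
# run() returns {label: text}, e.g. {"query": "What is the capital of France?"}
print(result)
```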
46 changes: 45 additions & 1 deletion haystack/utils/hf.py
@@ -15,7 +15,7 @@
import torch

with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import HfApi, InferenceClient
from huggingface_hub import HfApi, InferenceClient, model_info
from huggingface_hub.utils import RepositoryNotFoundError

logger = logging.getLogger(__name__)
@@ -98,6 +98,50 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio
return model_kwargs


def resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs: Dict[str, Any],
model: str,
task: Optional[str],
supported_tasks: List[str],
device: Optional[ComponentDevice],
token: Optional[Secret],
) -> Dict[str, Any]:
"""
Resolve the HuggingFace pipeline keyword arguments based on explicit user inputs.

:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize a
Hugging Face pipeline.
:param model: The name or path of a Hugging Face model on the Hugging Face Hub.
:param task: The task for the Hugging Face pipeline.
:param supported_tasks: The list of supported tasks to check the model's task against. If the model's task
is not in this list, a ValueError is raised.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The token to use as HTTP bearer authorization for remote files.
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
"""
transformers_import.check()

token = token.resolve_value() if token else None
# check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model)
huggingface_pipeline_kwargs.setdefault("token", token)

device = ComponentDevice.resolve_device(device)
device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

# task identification and validation
task = task or huggingface_pipeline_kwargs.get("task")
if task is None and isinstance(huggingface_pipeline_kwargs["model"], str):
task = model_info(huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]).pipeline_tag

if task not in supported_tasks:
raise ValueError(f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(supported_tasks)}.")
huggingface_pipeline_kwargs["task"] = task
return huggingface_pipeline_kwargs


def list_inference_deployed_models(headers: Optional[Dict] = None) -> List[str]:
"""
List all currently deployed models on HF TGI free tier
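The new helper is now the single place where model, task, device, and token defaults are resolved for HF pipelines. A hedged sketch of how a component might call it; the argument values are illustrative, and `Secret` comes from `haystack.utils` as in the diff above:

```python
from haystack.utils import Secret
from haystack.utils.hf import resolve_hf_pipeline_kwargs

kwargs = resolve_hf_pipeline_kwargs(
    huggingface_pipeline_kwargs={},
    model="MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
    task=None,  # resolved from the model's pipeline tag on the Hub
    supported_tasks=["zero-shot-classification"],
    device=None,  # falls back to ComponentDevice.resolve_device(None)
    token=Secret.from_env_var("HF_API_TOKEN", strict=False),
)
# kwargs now contains "model", "task", "token", and a device entry,
# ready to be passed to transformers.pipeline(**kwargs).
```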
@@ -0,0 +1,4 @@
---
features:
- |
Add a Zero Shot Text Router that uses a Hugging Face NLI model to classify texts against a set of provided labels and routes each text according to the label it was classified with.
@@ -11,7 +11,7 @@


class TestHuggingFaceLocalGenerator:
@patch("haystack.components.generators.hugging_face_local.model_info")
@patch("haystack.utils.hf.model_info")
def test_init_default(self, model_info_mock, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)
model_info_mock.return_value.pipeline_tag = "text2text-generation"
@@ -73,7 +73,7 @@ def test_init_task_in_huggingface_pipeline_kwargs(self):
"device": ComponentDevice.resolve_device(None).to_hf(),
}

@patch("haystack.components.generators.hugging_face_local.model_info")
@patch("haystack.utils.hf.model_info")
def test_init_task_inferred_from_model_name(self, model_info_mock):
model_info_mock.return_value.pipeline_tag = "text2text-generation"
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", token=None)
@@ -137,7 +137,7 @@ def test_init_fails_with_both_stopwords_and_stoppingcriteria(self):
generation_kwargs={"stopping_criteria": "fake-stopping-criteria"},
)

@patch("haystack.components.generators.hugging_face_local.model_info")
@patch("haystack.utils.hf.model_info")
def test_to_dict_default(self, model_info_mock):
model_info_mock.return_value.pipeline_tag = "text2text-generation"
