refactor: TGI Generator refactoring #7412

Status: Closed · wants to merge 10 commits
118 changes: 49 additions & 69 deletions haystack/components/generators/hugging_face_tgi.py
@@ -1,3 +1,4 @@
import warnings
from dataclasses import asdict
from typing import Any, Callable, Dict, Iterable, List, Optional
from urllib.parse import urlparse
@@ -6,45 +7,43 @@
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models
from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\" transformers'") as transformers_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
from huggingface_hub import (
InferenceClient,
TextGenerationOutput,
TextGenerationOutputToken,
TextGenerationStreamOutput,
)
from transformers import AutoTokenizer


logger = logging.getLogger(__name__)


# TODO: remove the default model in Haystack 2.3.0, as explained in the deprecation warning
Review comment (Contributor):

Let's instead open an issue for this (and the other deprecated changes) and add it to the 2.3.0 milestone (after creating it). We can add a link to this issue here.

DEFAULT_MODEL = "mistralai/Mistral-7B-v0.1"


@component
class HuggingFaceTGIGenerator:
"""
Enables text generation using HuggingFace Hub hosted non-chat LLMs.

This component is designed to seamlessly run inference on models deployed on the Text Generation Inference (TGI) backend.
You can use this component for LLMs hosted on Hugging Face inference endpoints, the rate-limited
Inference API tier.
You can use this component for LLMs hosted on Hugging Face inference endpoints.

Key Features and Compatibility:
- Primary Compatibility: designed to work seamlessly with any non-chat model deployed using the TGI
- Primary Compatibility: designed to work seamlessly with models deployed using the TGI
framework. For more information on TGI, visit [text-generation-inference](https://github.com/huggingface/text-generation-inference)

- Hugging Face Inference Endpoints: Supports inference of TGI chat LLMs deployed on Hugging Face
- Hugging Face Inference Endpoints: Supports inference of LLMs deployed on Hugging Face
inference endpoints. For more details, refer to [inference-endpoints](https://huggingface.co/inference-endpoints)

- Inference API Support: supports inference of TGI LLMs hosted on the rate-limited Inference
- Inference API Support: supports inference of LLMs hosted on the rate-limited Inference
API tier. Learn more about the Inference API at [inference-api](https://huggingface.co/inference-api).
Discover available chat models using the following command: `wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference | grep chat`
and simply use the model ID as the model parameter for this component. You'll also need to provide a valid
Hugging Face API token as the token parameter.
In this case, you need to provide a valid Hugging Face token.

- Custom TGI Endpoints: supports inference of TGI chat LLMs deployed on custom TGI endpoints. Anyone can
deploy their own TGI endpoint using the TGI framework. For more details, refer to [inference-endpoints](https://huggingface.co/inference-endpoints)

Input and Output Format:
- String Format: This component uses the str format for structuring both input and output,
Review comment (Contributor) on lines 36 to 49:

Can we remove this "market'y"-sounding docstring and merge the 3 links into the sentence above, similar to:

This component can be used with the HuggingFace TGI framework, Inference Endpoints and Inference API
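A hedged sketch of what that condensed docstring might look like; the wording below is illustrative and not text from this PR:

```python
# Illustrative only: a condensed class docstring merging the three links into one
# sentence, following the review suggestion. The exact wording is an assumption.
CONDENSED_DOCSTRING = """
Enables text generation using LLMs served with the Hugging Face TGI framework
(https://github.com/huggingface/text-generation-inference), on Hugging Face
Inference Endpoints (https://huggingface.co/inference-endpoints), or on the
rate-limited Inference API (https://huggingface.co/inference-api).
"""
```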

@@ -55,28 +54,26 @@ class HuggingFaceTGIGenerator:
from haystack.utils import Secret

client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token=Secret.from_token("<your-api-key>"))
client.warm_up()
response = client.run("What's Natural Language Processing?", max_new_tokens=120)
print(response)
```

Or for LLMs hosted on paid https://huggingface.co/inference-endpoints endpoint, and/or your own custom TGI endpoint.
In these two cases, you'll need to provide the URL of the endpoint as well as a valid token:
In these two cases, you'll need to provide the URL of the endpoint.
For Inference Endpoints, you also need to provide a valid Hugging Face token.

```python
from haystack.components.generators import HuggingFaceTGIGenerator
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1",
url="<your-tgi-endpoint-url>",
client = HuggingFaceTGIGenerator(url="<your-tgi-endpoint-url>",
token=Secret.from_token("<your-api-key>"))
client.warm_up()
response = client.run("What's Natural Language Processing?")
print(response)
```
"""

def __init__(
self,
model: str = "mistralai/Mistral-7B-v0.1",
model: Optional[str] = None,
url: Optional[str] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
generation_kwargs: Optional[Dict[str, Any]] = None,
@@ -87,10 +84,12 @@ def __init__(
Initialize the HuggingFaceTGIGenerator instance.

:param model:
A string representing the model id on HF Hub. Default is "mistralai/Mistral-7B-v0.1".
An optional string representing the model id on HF Hub.
If not provided, the `url` parameter must be set to a valid TGI endpoint.
:param url:
An optional string representing the URL of the TGI endpoint. If the url is not provided, check if the model
is deployed on the free tier of the HF inference API.
An optional string representing the URL of the TGI endpoint.
If not provided, the `model` parameter must be set to a valid model id and the Hugging Face Inference API
will be used.
:param token: The HuggingFace token to use as HTTP bearer authorization
You can find your HF token in your [account settings](https://huggingface.co/settings/tokens)
:param generation_kwargs:
@@ -100,15 +99,25 @@ def __init__(
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
"""
transformers_import.check()
huggingface_hub_import.check()

if not model and not url:
warnings.warn(
f"Neither `model` nor `url` is provided. The component will use the default model: {DEFAULT_MODEL}. "
"This behavior is deprecated and will be removed in Haystack 2.3.0.",
DeprecationWarning,
)
model = DEFAULT_MODEL
elif model and url:
logger.warning("Both `model` and `url` are provided. The `model` parameter will be ignored. ")

if url:
r = urlparse(url)
is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
if not is_valid_url:
raise ValueError(f"Invalid TGI endpoint URL provided: {url}")

check_valid_model(model, HFModelType.GENERATION, token)
elif model:
check_valid_model(model, HFModelType.GENERATION, token)

# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -121,28 +130,14 @@ def __init__(
self.url = url
self.token = token
self.generation_kwargs = generation_kwargs
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.streaming_callback = streaming_callback
self.tokenizer = None
self._client = InferenceClient(url or model, token=token.resolve_value() if token else None)

def warm_up(self) -> None:
"""
Initializes the component.
"""

# is this user using HF free tier inference API?
if self.model and not self.url:
deployed_models = list_inference_deployed_models()
# Determine if the specified model is deployed in the free tier.
if self.model not in deployed_models:
raise ValueError(
f"The model {self.model} is not deployed on the free tier of the HF inference API. "
"To use free tier models provide the model ID and the token. Valid models are: "
f"{deployed_models}"
)

self.tokenizer = AutoTokenizer.from_pretrained(
self.model, token=self.token.resolve_value() if self.token else None
warnings.warn(
"The `warm_up` method of `HuggingFaceTGIGenerator` does nothing and is momentarily maintained to ensure backward compatibility. "
"It is deprecated and will be removed in Haystack 2.3.0.",
DeprecationWarning,
)

def to_dict(self) -> Dict[str, Any]:
@@ -203,21 +198,16 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
num_responses = generation_kwargs.pop("n", 1)
generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", []))

if self.tokenizer is None:
raise RuntimeError("Please call warm_up() before running LLM inference.")

prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False))

if self.streaming_callback:
if num_responses > 1:
raise ValueError("Cannot stream multiple responses, please set n=1.")

return self._run_streaming(prompt, prompt_token_count, generation_kwargs)
return self._run_streaming(prompt, generation_kwargs)

return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs)
return self._run_non_streaming(prompt, num_responses, generation_kwargs)

def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]):
res_chunk: Iterable[TextGenerationStreamOutput] = self.client.text_generation(
def _run_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
res_chunk: Iterable[TextGenerationStreamOutput] = self._client.text_generation(
prompt, details=True, stream=True, **generation_kwargs
)
chunks: List[StreamingChunk] = []
@@ -232,32 +222,22 @@ def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)
metadata = {
"finish_reason": chunks[-1].meta.get("finish_reason", None),
"model": self.client.model,
"usage": {
"completion_tokens": chunks[-1].meta.get("generated_tokens", 0),
"prompt_tokens": prompt_token_count,
"total_tokens": prompt_token_count + chunks[-1].meta.get("generated_tokens", 0),
},
"model": self._client.model,
"usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)},
Review comment (Contributor):

Let's keep the keys with values of zero until 2.3.0.
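A minimal sketch of what this request would mean for the streaming metadata, assuming the removed keys are kept as zero-valued placeholders until 2.3.0; the names `chunks` and `self._client` come from the surrounding `_run_streaming` method, and this is not the code merged in this PR:

```python
# Sketch: retain the deprecated usage keys with zero values until Haystack 2.3.0,
# since the prompt is no longer tokenized locally after this refactoring.
metadata = {
    "finish_reason": chunks[-1].meta.get("finish_reason", None),
    "model": self._client.model,
    "usage": {
        "completion_tokens": chunks[-1].meta.get("generated_tokens", 0),
        "prompt_tokens": 0,  # placeholder, to be removed in 2.3.0
        "total_tokens": 0,  # placeholder, to be removed in 2.3.0
    },
}
```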

}
return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}

def _run_non_streaming(
self, prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any]
):
def _run_non_streaming(self, prompt: str, num_responses: int, generation_kwargs: Dict[str, Any]):
responses: List[str] = []
all_metadata: List[Dict[str, Any]] = []
for _i in range(num_responses):
tgr: TextGenerationOutput = self.client.text_generation(prompt, details=True, **generation_kwargs)
tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
all_metadata.append(
{
"model": self.client.model,
"model": self._client.model,
"index": _i,
"finish_reason": tgr.details.finish_reason,
"usage": {
"completion_tokens": len(tgr.details.tokens),
"prompt_tokens": prompt_token_count,
"total_tokens": prompt_token_count + len(tgr.details.tokens),
},
"usage": {"completion_tokens": len(tgr.details.tokens)},
Review comment (Contributor):

Same as above.

}
)
responses.append(tgr.generated_text)
11 changes: 11 additions & 0 deletions releasenotes/notes/tgi-refactoring-62885781f81e18d1.yaml
@@ -0,0 +1,11 @@
---
enhancements:
- |
Improve the HuggingFaceTGIGenerator component.
- Lighter: the component only depends on the `huggingface_hub` library.
- If initialized with an appropriate `url` parameter, the component can run on a local network and does not require internet access.
deprecations:
- |
- The HuggingFaceTGIGenerator component requires specifying either a `url` or `model` parameter.
Starting from Haystack 2.3.0, the component will raise an error if neither parameter is provided.
- The `warm_up` method of the HuggingFaceTGIGenerator component is deprecated and will be removed in the 2.3.0 release.
Review comment (Contributor):

We should also mention the removal of the keys in the usage dict.
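As background for the enhancement described in this release note, a minimal usage sketch against a locally served TGI instance could look like the following; the URL and prompt are placeholders, not values taken from this PR:

```python
from haystack.components.generators import HuggingFaceTGIGenerator

# Sketch: point the refactored component at a TGI server on the local network.
# When a `url` is given, typically no Hugging Face token or internet access is needed.
generator = HuggingFaceTGIGenerator(url="http://localhost:8080")

result = generator.run("What is Natural Language Processing?", generation_kwargs={"max_new_tokens": 64})
print(result["replies"][0])
print(result["meta"][0])
```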

57 changes: 17 additions & 40 deletions test/components/generators/test_hugging_face_tgi.py
@@ -2,24 +2,12 @@

import pytest
from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.components.generators import HuggingFaceTGIGenerator
from haystack.components.generators.hugging_face_tgi import DEFAULT_MODEL, HuggingFaceTGIGenerator
from haystack.dataclasses import StreamingChunk
from haystack.utils.auth import Secret


@pytest.fixture
def mock_list_inference_deployed_models():
with patch(
"haystack.components.generators.hugging_face_tgi.list_inference_deployed_models",
MagicMock(
return_value=["HuggingFaceH4/zephyr-7b-alpha", "HuggingFaceH4/zephyr-7b-alpha", "mistralai/Mistral-7B-v0.1"]
),
) as mock:
yield mock


@pytest.fixture
def mock_check_valid_model():
with patch(
@@ -68,13 +56,12 @@ def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_
**{"stop_sequences": ["stop"]},
**{"max_new_tokens": 512},
}
assert generator.tokenizer is None
assert generator.client is not None
assert generator._client is not None
assert generator.streaming_callback == streaming_callback

def test_to_dict(self, mock_check_valid_model):
# Initialize the HuggingFaceRemoteGenerator object with valid parameters
generator = HuggingFaceTGIGenerator(
model="mistralai/Mistral-7B-v0.1",
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
@@ -110,14 +97,13 @@ def test_initialize_with_invalid_url(self, mock_check_valid_model):
with pytest.raises(ValueError):
HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", url="invalid_url")

def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
# When custom TGI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id
mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
with pytest.raises(RepositoryNotFoundError):
HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com")
def test_initialize_without_model_or_url(self, mock_check_valid_model):
generator = HuggingFaceTGIGenerator(model=None, url=None)

assert generator.model == DEFAULT_MODEL

def test_generate_text_response_with_valid_prompt_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
self, mock_check_valid_model, mock_text_generation
):
model = "mistralai/Mistral-7B-v0.1"

@@ -131,7 +117,6 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
stop_words=stop_words,
streaming_callback=streaming_callback,
)
generator.warm_up()

prompt = "Hello, how are you?"
response = generator.run(prompt)
@@ -151,7 +136,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
assert [isinstance(reply, str) for reply in response["replies"]]

def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
self, mock_check_valid_model, mock_text_generation
):
model = "mistralai/Mistral-7B-v0.1"
generation_kwargs = {"n": 3}
@@ -164,7 +149,6 @@ def test_generate_multiple_text_responses_with_valid_prompt_and_generation_param
stop_words=stop_words,
streaming_callback=streaming_callback,
)
generator.warm_up()

prompt = "Hello, how are you?"
response = generator.run(prompt)
@@ -201,11 +185,8 @@ def test_initialize_with_invalid_model(self, mock_check_valid_model):
streaming_callback=streaming_callback,
)

def test_generate_text_with_stop_words(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
):
generator = HuggingFaceTGIGenerator()
generator.warm_up()
def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_text_generation):
generator = HuggingFaceTGIGenerator("HuggingFaceH4/zephyr-7b-alpha")

# Generate text response with stop words
response = generator.run("How are you?", generation_kwargs={"stop_words": ["stop", "words"]})
@@ -226,11 +207,8 @@ def test_generate_text_with_stop_words(
assert len(response["meta"]) > 0
assert [isinstance(reply, dict) for reply in response["replies"]]

def test_generate_text_with_custom_generation_parameters(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
):
generator = HuggingFaceTGIGenerator()
generator.warm_up()
def test_generate_text_with_custom_generation_parameters(self, mock_check_valid_model, mock_text_generation):
generator = HuggingFaceTGIGenerator("HuggingFaceH4/zephyr-7b-alpha")

generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
response = generator.run("How are you?", generation_kwargs=generation_kwargs)
@@ -252,9 +230,7 @@ def test_generate_text_with_custom_generation_parameters(
assert len(response["meta"]) > 0
assert [isinstance(reply, str) for reply in response["replies"]]

def test_generate_text_with_streaming_callback(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models
):
def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_text_generation):
streaming_call_count = 0

# Define the streaming callback function
@@ -264,8 +240,9 @@ def streaming_callback_fn(chunk: StreamingChunk):
assert isinstance(chunk, StreamingChunk)

# Create an instance of HuggingFaceRemoteGenerator
generator = HuggingFaceTGIGenerator(streaming_callback=streaming_callback_fn)
generator.warm_up()
generator = HuggingFaceTGIGenerator(
model="HuggingFaceH4/zephyr-7b-alpha", streaming_callback=streaming_callback_fn
)

# Create a fake streamed response
# Don't remove self
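The deprecation paths added in `hugging_face_tgi.py` are not covered by the tests shown above; a sketch of how they could be exercised, reusing the existing `mock_check_valid_model` fixture, is below. These tests are illustrative and not part of this PR.

```python
import pytest

from haystack.components.generators.hugging_face_tgi import DEFAULT_MODEL, HuggingFaceTGIGenerator


def test_init_without_model_or_url_warns_and_uses_default(mock_check_valid_model):
    # Neither `model` nor `url` is given: a DeprecationWarning is emitted and
    # the component falls back to DEFAULT_MODEL.
    with pytest.warns(DeprecationWarning):
        generator = HuggingFaceTGIGenerator(model=None, url=None)
    assert generator.model == DEFAULT_MODEL


def test_warm_up_emits_deprecation_warning(mock_check_valid_model):
    generator = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1")
    # `warm_up` is kept only for backward compatibility and should warn.
    with pytest.warns(DeprecationWarning):
        generator.warm_up()
```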