refactor: TGI Generator refactoring #7412
Changes from all commits
4969ae4
cebc8a9
75a4e99
f3b3d59
647fbcb
b8f0195
f1effa1
6115b18
6897afb
04ace00
@@ -1,3 +1,4 @@
import warnings
from dataclasses import asdict
from typing import Any, Callable, Dict, Iterable, List, Optional
from urllib.parse import urlparse

@@ -6,45 +7,43 @@
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models
from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\" transformers'") as transformers_import:
with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        InferenceClient,
        TextGenerationOutput,
        TextGenerationOutputToken,
        TextGenerationStreamOutput,
    )
    from transformers import AutoTokenizer


logger = logging.getLogger(__name__)


# TODO: remove the default model in Haystack 2.3.0, as explained in the deprecation warning
DEFAULT_MODEL = "mistralai/Mistral-7B-v0.1"


@component
class HuggingFaceTGIGenerator:
    """
    Enables text generation using HuggingFace Hub hosted non-chat LLMs.

    This component is designed to seamlessly inference models deployed on the Text Generation Inference (TGI) backend.
    You can use this component for LLMs hosted on Hugging Face inference endpoints, the rate-limited
    Inference API tier.
    You can use this component for LLMs hosted on Hugging Face inference endpoints.

    Key Features and Compatibility:
    - Primary Compatibility: designed to work seamlessly with any non-based model deployed using the TGI
    - Primary Compatibility: designed to work seamlessly with models deployed using the TGI
      framework. For more information on TGI, visit [text-generation-inference](https://github.com/huggingface/text-generation-inference)

    - Hugging Face Inference Endpoints: Supports inference of TGI chat LLMs deployed on Hugging Face
    - Hugging Face Inference Endpoints: Supports inference of LLMs deployed on Hugging Face
      inference endpoints. For more details, refer to [inference-endpoints](https://huggingface.co/inference-endpoints)

    - Inference API Support: supports inference of TGI LLMs hosted on the rate-limited Inference
    - Inference API Support: supports inference of LLMs hosted on the rate-limited Inference
      API tier. Learn more about the Inference API at [inference-api](https://huggingface.co/inference-api).
      Discover available chat models using the following command: `wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference | grep chat`
      and simply use the model ID as the model parameter for this component. You'll also need to provide a valid
      Hugging Face API token as the token parameter.
      In this case, you need to provide a valid Hugging Face token.

    - Custom TGI Endpoints: supports inference of TGI chat LLMs deployed on custom TGI endpoints. Anyone can
      deploy their own TGI endpoint using the TGI framework. For more details, refer to [inference-endpoints](https://huggingface.co/inference-endpoints)

    Input and Output Format:
    - String Format: This component uses the str format for structuring both input and output,
Comment on lines 36 to 49:
Can we remove this "market'y"-sounding docstring and merge the 3 links into the sentence above, similar to:
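One possible condensed opening for the docstring, with the three links folded into the prose. This wording is only a sketch, not text from the PR or from the reviewer:

```python
"""
Generates text using LLMs served through a Text Generation Inference (TGI) backend
(https://github.com/huggingface/text-generation-inference), including models hosted on
Hugging Face Inference Endpoints (https://huggingface.co/inference-endpoints), the
rate-limited Inference API (https://huggingface.co/inference-api), or your own TGI deployment.
"""
```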
@@ -55,28 +54,26 @@ class HuggingFaceTGIGenerator:
    from haystack.utils import Secret

    client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token=Secret.from_token("<your-api-key>"))
    client.warm_up()
    response = client.run("What's Natural Language Processing?", max_new_tokens=120)
    print(response)
    ```

    Or for LLMs hosted on paid https://huggingface.co/inference-endpoints endpoint, and/or your own custom TGI endpoint.
    In these two cases, you'll need to provide the URL of the endpoint as well as a valid token:
    In these two cases, you'll need to provide the URL of the endpoint.
    For Inference Endpoints, you also need to provide a valid Hugging Face token.

    ```python
    from haystack.components.generators import HuggingFaceTGIGenerator
    client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1",
        url="<your-tgi-endpoint-url>",
    client = HuggingFaceTGIGenerator(url="<your-tgi-endpoint-url>",
        token=Secret.from_token("<your-api-key>"))
    client.warm_up()
    response = client.run("What's Natural Language Processing?")
    print(response)
    ```
    """

    def __init__(
        self,
        model: str = "mistralai/Mistral-7B-v0.1",
        model: Optional[str] = None,
        url: Optional[str] = None,
        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,

@@ -87,10 +84,12 @@ def __init__(
        Initialize the HuggingFaceTGIGenerator instance.

        :param model:
            A string representing the model id on HF Hub. Default is "mistralai/Mistral-7B-v0.1".
            An optional string representing the model id on HF Hub.
            If not provided, the `url` parameter must be set to a valid TGI endpoint.
        :param url:
            An optional string representing the URL of the TGI endpoint. If the url is not provided, check if the model
            is deployed on the free tier of the HF inference API.
            An optional string representing the URL of the TGI endpoint.
            If not provided, the `model` parameter must be set to a valid model id and the Hugging Face Inference API
            will be used.
        :param token: The HuggingFace token to use as HTTP bearer authorization
            You can find your HF token in your [account settings](https://huggingface.co/settings/tokens)
        :param generation_kwargs:

@@ -100,15 +99,25 @@ def __init__(
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """
        transformers_import.check()
        huggingface_hub_import.check()

        if not model and not url:
            warnings.warn(
                f"Neither `model` nor `url` is provided. The component will use the default model: {DEFAULT_MODEL}. "
                "This behavior is deprecated and will be removed in Haystack 2.3.0.",
                DeprecationWarning,
            )
            model = DEFAULT_MODEL
        elif model and url:
            logger.warning("Both `model` and `url` are provided. The `model` parameter will be ignored. ")

        if url:
            r = urlparse(url)
            is_valid_url = all([r.scheme in ["http", "https"], r.netloc])
            if not is_valid_url:
                raise ValueError(f"Invalid TGI endpoint URL provided: {url}")

        check_valid_model(model, HFModelType.GENERATION, token)
        elif model:
            check_valid_model(model, HFModelType.GENERATION, token)

        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
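For reference, a rough sketch of how the new validation behaves from a caller's point of view. The localhost URL is just a placeholder, not something defined in this PR:

```python
from haystack.components.generators import HuggingFaceTGIGenerator

# Neither `model` nor `url`: falls back to the default model and emits a DeprecationWarning.
# Note that validating the model id may reach out to the Hugging Face Hub.
generator = HuggingFaceTGIGenerator()

# `url` only: the component targets a (possibly local) TGI endpoint; no model id is required.
generator = HuggingFaceTGIGenerator(url="http://localhost:8080")  # placeholder endpoint

# A malformed `url` is rejected at construction time.
try:
    HuggingFaceTGIGenerator(url="not-a-valid-url")
except ValueError as err:
    print(err)  # Invalid TGI endpoint URL provided: not-a-valid-url
```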
@@ -121,28 +130,14 @@ def __init__(
        self.url = url
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
        self.streaming_callback = streaming_callback
        self.tokenizer = None
        self._client = InferenceClient(url or model, token=token.resolve_value() if token else None)

    def warm_up(self) -> None:
        """
        Initializes the component.
        """

        # is this user using HF free tier inference API?
        if self.model and not self.url:
            deployed_models = list_inference_deployed_models()
            # Determine if the specified model is deployed in the free tier.
            if self.model not in deployed_models:
                raise ValueError(
                    f"The model {self.model} is not deployed on the free tier of the HF inference API. "
                    "To use free tier models provide the model ID and the token. Valid models are: "
                    f"{deployed_models}"
                )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model, token=self.token.resolve_value() if self.token else None
        warnings.warn(
            "The `warm_up` method of `HuggingFaceTGIGenerator` does nothing and is momentarily maintained to ensure backward compatibility. "
            "It is deprecated and will be removed in Haystack 2.3.0.",
            DeprecationWarning,
        )

    def to_dict(self) -> Dict[str, Any]:

@@ -203,21 +198,16 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        num_responses = generation_kwargs.pop("n", 1)
        generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", []))

        if self.tokenizer is None:
            raise RuntimeError("Please call warm_up() before running LLM inference.")

        prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False))

        if self.streaming_callback:
            if num_responses > 1:
                raise ValueError("Cannot stream multiple responses, please set n=1.")

            return self._run_streaming(prompt, prompt_token_count, generation_kwargs)
            return self._run_streaming(prompt, generation_kwargs)

        return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs)
        return self._run_non_streaming(prompt, num_responses, generation_kwargs)

    def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]):
        res_chunk: Iterable[TextGenerationStreamOutput] = self.client.text_generation(
    def _run_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
        res_chunk: Iterable[TextGenerationStreamOutput] = self._client.text_generation(
            prompt, details=True, stream=True, **generation_kwargs
        )
        chunks: List[StreamingChunk] = []

@@ -232,32 +222,22 @@ def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs
            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
        metadata = {
            "finish_reason": chunks[-1].meta.get("finish_reason", None),
            "model": self.client.model,
            "usage": {
                "completion_tokens": chunks[-1].meta.get("generated_tokens", 0),
                "prompt_tokens": prompt_token_count,
                "total_tokens": prompt_token_count + chunks[-1].meta.get("generated_tokens", 0),
            },
            "model": self._client.model,
            "usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)},
Let's keep the keys with values of zero until 2.3.0.
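Concretely, the reviewer's suggestion for the streaming metadata could look roughly like this. This is a sketch only, not code merged in this PR:

```python
completion_tokens = chunks[-1].meta.get("generated_tokens", 0)
metadata = {
    "finish_reason": chunks[-1].meta.get("finish_reason", None),
    "model": self._client.model,
    "usage": {
        "completion_tokens": completion_tokens,
        # Prompt token counting went away with the tokenizer removal;
        # keep the keys with zero values until Haystack 2.3.0 for backward compatibility.
        "prompt_tokens": 0,
        "total_tokens": completion_tokens,
    },
}
```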
        }
        return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}

    def _run_non_streaming(
        self, prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any]
    ):
    def _run_non_streaming(self, prompt: str, num_responses: int, generation_kwargs: Dict[str, Any]):
        responses: List[str] = []
        all_metadata: List[Dict[str, Any]] = []
        for _i in range(num_responses):
            tgr: TextGenerationOutput = self.client.text_generation(prompt, details=True, **generation_kwargs)
            tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
            all_metadata.append(
                {
                    "model": self.client.model,
                    "model": self._client.model,
                    "index": _i,
                    "finish_reason": tgr.details.finish_reason,
                    "usage": {
                        "completion_tokens": len(tgr.details.tokens),
                        "prompt_tokens": prompt_token_count,
                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
                    },
                    "usage": {"completion_tokens": len(tgr.details.tokens)},
Same as above.
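For the non-streaming path, the same suggestion would amount to something along these lines (again a sketch, not the merged code):

```python
"usage": {
    "completion_tokens": len(tgr.details.tokens),
    "prompt_tokens": 0,  # kept with a zero value until Haystack 2.3.0
    "total_tokens": len(tgr.details.tokens),
},
```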
                }
            )
            responses.append(tgr.generated_text)
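Since the streaming path is what changes above, here is a short usage sketch of the refactored component with a streaming callback. The endpoint URL is a placeholder and the printed usage values are illustrative:

```python
from haystack.components.generators import HuggingFaceTGIGenerator
from haystack.dataclasses import StreamingChunk

def print_chunk(chunk: StreamingChunk) -> None:
    # Called once per generated chunk while the response streams in.
    print(chunk.content, end="", flush=True)

generator = HuggingFaceTGIGenerator(
    url="http://localhost:8080",  # placeholder for your own TGI endpoint
    streaming_callback=print_chunk,
)
result = generator.run("What's Natural Language Processing?")
print(result["meta"][0]["usage"])  # e.g. {'completion_tokens': 87} after this PR
```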
@@ -0,0 +1,11 @@
---
enhancements:
  - |
    Improve the HuggingFaceTGIGenerator component.
    - Lighter: the component only depends on the `huggingface_hub` library.
    - If initialized with an appropriate `url` parameter, the component can run on a local network and does not require internet access.
deprecations:
  - |
    - The HuggingFaceTGIGenerator component requires specifying either a `url` or `model` parameter.
      Starting from Haystack 2.3.0, the component will raise an error if neither parameter is provided.
    - The `warm_up` method of the HuggingFaceTGIGenerator component is deprecated and will be removed in the 2.3.0 release.
We should also mention the removal of the keys in the `usage` metadata.
Let's instead open an issue for this (and the other deprecated changes) and add it to the 2.3.0 milestone (after creating it). We can add a link to this issue here.