From 4969ae4809db692e044c2eaa71ef1403956deab8 Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Fri, 22 Mar 2024 18:52:03 +0100 Subject: [PATCH 1/9] some progress --- .../components/generators/hugging_face_tgi.py | 90 +++++++------------ haystack/utils/hf.py | 8 +- .../generators/test_hugging_face_tgi.py | 4 +- 3 files changed, 39 insertions(+), 63 deletions(-) diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py index 06b065be40..011f324893 100644 --- a/haystack/components/generators/hugging_face_tgi.py +++ b/haystack/components/generators/hugging_face_tgi.py @@ -8,10 +8,9 @@ from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models -with LazyImport(message="Run 'pip install transformers'") as transformers_import: +with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import: from huggingface_hub import InferenceClient from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Token - from transformers import AutoTokenizer logger = logging.getLogger(__name__) @@ -61,8 +60,7 @@ class HuggingFaceTGIGenerator: ```python from haystack.components.generators import HuggingFaceTGIGenerator - client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", - url="<your-tgi-endpoint-url>", + client = HuggingFaceTGIGenerator(url="<your-tgi-endpoint-url>", token=Secret.from_token("<your-api-key>")) client.warm_up() response = client.run("What's Natural Language Processing?") @@ -72,7 +70,7 @@ class HuggingFaceTGIGenerator: def __init__( self, - model: str = "mistralai/Mistral-7B-v0.1", + model: Optional[str] = None, url: Optional[str] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), generation_kwargs: Optional[Dict[str, Any]] = None, @@ -83,10 +81,12 @@ def __init__( Initialize the HuggingFaceTGIGenerator instance. :param model: - A string representing the model id on HF Hub. Default is "mistralai/Mistral-7B-v0.1". + An optional string representing the model id on HF Hub. + If not provided, the `url` parameter must be set to a valid TGI endpoint. :param url: - An optional string representing the URL of the TGI endpoint. If the url is not provided, check if the model - is deployed on the free tier of the HF inference API. + An optional string representing the URL of the TGI endpoint. + If not provided, the `model` parameter must be set to a valid model id and the Hugging Face Inference API + will be used. :param token: The HuggingFace token to use as HTTP bearer authorization You can find your HF token in your [account settings](https://huggingface.co/settings/tokens) :param generation_kwargs: @@ -96,7 +96,10 @@ def __init__( :param stop_words: An optional list of strings representing the stop words. :param streaming_callback: An optional callable for handling streaming responses. """ - transformers_import.check() + huggingface_hub_import.check() + + if (not model and not url) or (model and url): + raise ValueError("You must provide either a model or a TGI endpoint URL.") if url: r = urlparse(url) @@ -104,7 +107,16 @@ def __init__( if not is_valid_url: raise ValueError(f"Invalid TGI endpoint URL provided: {url}") - check_valid_model(model, HFModelType.GENERATION, token) + if model: + check_valid_model(model, HFModelType.GENERATION, token) + # TODO: remove this check when the huggingface_hub bugfix release is out + # https://github.com/huggingface/huggingface_hub/issues/2135 + tgi_deployed_models = list_inference_deployed_models() + if model not in tgi_deployed_models: + raise ValueError( + f"The model {model} is not correctly supported by the free tier of the HF inference API. " + f"Valid models are: {tgi_deployed_models}" + ) # handle generation kwargs setup generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} @@ -117,29 +129,8 @@ def __init__( self.url = url self.token = token self.generation_kwargs = generation_kwargs - self.client = InferenceClient(url or model, token=token.resolve_value() if token else None) + self._client = InferenceClient(url or model, token=token.resolve_value() if token else None) self.streaming_callback = streaming_callback - self.tokenizer = None - - def warm_up(self) -> None: - """ - Initializes the component. - """ - - # is this user using HF free tier inference API? - if self.model and not self.url: - deployed_models = list_inference_deployed_models() - # Determine if the specified model is deployed in the free tier. - if self.model not in deployed_models: - raise ValueError( - f"The model {self.model} is not deployed on the free tier of the HF inference API. " - "To use free tier models provide the model ID and the token. Valid models are: " - f"{deployed_models}" - ) - - self.tokenizer = AutoTokenizer.from_pretrained( - self.model, token=self.token.resolve_value() if self.token else None - ) def to_dict(self) -> Dict[str, Any]: """ @@ -199,21 +190,16 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): num_responses = generation_kwargs.pop("n", 1) generation_kwargs.setdefault("stop_sequences", []).extend(generation_kwargs.pop("stop_words", [])) - if self.tokenizer is None: - raise RuntimeError("Please call warm_up() before running LLM inference.") - - prompt_token_count = len(self.tokenizer.encode(prompt, add_special_tokens=False)) - if self.streaming_callback: if num_responses > 1: raise ValueError("Cannot stream multiple responses, please set n=1.") - return self._run_streaming(prompt, prompt_token_count, generation_kwargs) + return self._run_streaming(prompt, generation_kwargs) - return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs) + return self._run_non_streaming(prompt, num_responses, generation_kwargs) - def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]): - res_chunk: Iterable[TextGenerationStreamResponse] = self.client.text_generation( + def _run_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]): + res_chunk: Iterable[TextGenerationStreamResponse] = self._client.text_generation( prompt, details=True, stream=True, **generation_kwargs ) chunks: List[StreamingChunk] = [] @@ -228,32 +214,22 @@ def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) metadata = { "finish_reason": chunks[-1].meta.get("finish_reason", None), - "model": self.client.model, - "usage": { - "completion_tokens": chunks[-1].meta.get("generated_tokens", 0), - "prompt_tokens": prompt_token_count, - "total_tokens": prompt_token_count + chunks[-1].meta.get("generated_tokens", 0), - }, + "model": self._client.model, + "usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)}, } return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]} - def _run_non_streaming( - self, prompt: str, prompt_token_count: int, num_responses: int, generation_kwargs: Dict[str, Any] - ): + def _run_non_streaming(self, prompt: str, num_responses: int, generation_kwargs: Dict[str, Any]): responses: List[str] = [] all_metadata: List[Dict[str, Any]] = [] for _i in range(num_responses): - tgr: TextGenerationResponse = self.client.text_generation(prompt, details=True, **generation_kwargs) + tgr: TextGenerationResponse = self._client.text_generation(prompt, details=True, **generation_kwargs) all_metadata.append( { - "model": self.client.model, + "model": self._client.model, "index": _i, "finish_reason": tgr.details.finish_reason.value, - "usage": { - "completion_tokens": len(tgr.details.tokens), - "prompt_tokens": prompt_token_count, - "total_tokens": prompt_token_count + len(tgr.details.tokens), - }, + "usage": {"completion_tokens": len(tgr.details.tokens)}, } ) responses.append(tgr.generated_text) diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index b3afe20c17..ace6f4a379 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -14,7 +14,7 @@ with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import: import torch -with LazyImport(message="Run 'pip install transformers'") as transformers_import: +with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import: from huggingface_hub import HfApi, InferenceClient, model_info from huggingface_hub.utils import RepositoryNotFoundError @@ -120,7 +120,7 @@ def resolve_hf_pipeline_kwargs( :param token: The token to use as HTTP bearer authorization for remote files. If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. """ - transformers_import.check() + huggingface_hub_import.check() token = token.resolve_value() if token else None # check if the huggingface_pipeline_kwargs contain the essential parameters @@ -173,7 +173,7 @@ def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Se :param token: The optional authentication token. :raises ValueError: If the model is not found or is not a embedding model. """ - transformers_import.check() + huggingface_hub_import.check() api = HfApi() try: @@ -202,7 +202,7 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte :param additional_accepted_params: An optional list of strings representing additional accepted parameters. :raises ValueError: If any unknown text generation parameters are provided. """ - transformers_import.check() + huggingface_hub_import.check() if kwargs: accepted_params = { diff --git a/test/components/generators/test_hugging_face_tgi.py b/test/components/generators/test_hugging_face_tgi.py index 19588e938d..a4be204c0f 100644 --- a/test/components/generators/test_hugging_face_tgi.py +++ b/test/components/generators/test_hugging_face_tgi.py @@ -68,13 +68,13 @@ def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_ **{"stop_sequences": ["stop"]}, **{"max_new_tokens": 512}, } - assert generator.tokenizer is None - assert generator.client is not None + assert generator._client is not None assert generator.streaming_callback == streaming_callback def test_to_dict(self, mock_check_valid_model): # Initialize the HuggingFaceRemoteGenerator object with valid parameters generator = HuggingFaceTGIGenerator( + model="mistralai/Mistral-7B-v0.1", token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"n": 5}, stop_words=["stop", "words"], From cebc8a9799538a68f42c180747d254c247c87573 Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Fri, 22 Mar 2024 19:59:04 +0100 Subject: [PATCH 2/9] more progress --- haystack/components/generators/hugging_face_tgi.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py index 011f324893..cda33eea5b 100644 --- a/haystack/components/generators/hugging_face_tgi.py +++ b/haystack/components/generators/hugging_face_tgi.py @@ -98,7 +98,7 @@ def __init__( """ huggingface_hub_import.check() - if (not model and not url) or (model and url): + if not model and not url: raise ValueError("You must provide either a model or a TGI endpoint URL.") if url: @@ -106,8 +106,10 @@ def __init__( is_valid_url = all([r.scheme in ["http", "https"], r.netloc]) if not is_valid_url: raise ValueError(f"Invalid TGI endpoint URL provided: {url}") + if model: + logger.warning("Both model and url are provided. The model parameter will be ignored. ") - if model: + if model and not url: check_valid_model(model, HFModelType.GENERATION, token) # TODO: remove this check when the huggingface_hub bugfix release is out # https://github.com/huggingface/huggingface_hub/issues/2135 @@ -129,8 +131,8 @@ def __init__( self.url = url self.token = token self.generation_kwargs = generation_kwargs - self._client = InferenceClient(url or model, token=token.resolve_value() if token else None) self.streaming_callback = streaming_callback + self._client = InferenceClient(url or model, token=token.resolve_value() if token else None) def to_dict(self) -> Dict[str, Any]: """ From 75a4e992c3f82d6a39acd07455bfe7ab234468b6 Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Fri, 22 Mar 2024 20:34:58 +0100 Subject: [PATCH 3/9] improve tests --- .../components/generators/hugging_face_tgi.py | 6 +-- .../generators/test_hugging_face_tgi.py | 40 +++++++++---------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py index cda33eea5b..b860fb726b 100644 --- a/haystack/components/generators/hugging_face_tgi.py +++ b/haystack/components/generators/hugging_face_tgi.py @@ -50,7 +50,6 @@ class HuggingFaceTGIGenerator: from haystack.utils import Secret client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token=Secret.from_token("<your-api-key>")) - client.warm_up() response = client.run("What's Natural Language Processing?", max_new_tokens=120) print(response) ``` @@ -62,7 +61,6 @@ class HuggingFaceTGIGenerator: from haystack.components.generators import HuggingFaceTGIGenerator client = HuggingFaceTGIGenerator(url="<your-tgi-endpoint-url>", token=Secret.from_token("<your-api-key>")) - client.warm_up() response = client.run("What's Natural Language Processing?") print(response) ``` @@ -100,14 +98,14 @@ def __init__( if not model and not url: raise ValueError("You must provide either a model or a TGI endpoint URL.") + if model and url: + logger.warning("Both `model` and `url` are provided. The `model` parameter will be ignored. ") if url: r = urlparse(url) is_valid_url = all([r.scheme in ["http", "https"], r.netloc]) if not is_valid_url: raise ValueError(f"Invalid TGI endpoint URL provided: {url}") - if model: - logger.warning("Both model and url are provided. The model parameter will be ignored. ") if model and not url: check_valid_model(model, HFModelType.GENERATION, token) diff --git a/test/components/generators/test_hugging_face_tgi.py b/test/components/generators/test_hugging_face_tgi.py index a4be204c0f..08106949c6 100644 --- a/test/components/generators/test_hugging_face_tgi.py +++ b/test/components/generators/test_hugging_face_tgi.py @@ -47,7 +47,9 @@ def streaming_callback_handler(x): class TestHuggingFaceTGIGenerator: - def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model): + def test_initialize_with_valid_model_and_generation_parameters( + self, mock_check_valid_model, mock_list_inference_deployed_models + ): model = "HuggingFaceH4/zephyr-7b-alpha" generation_kwargs = {"n": 1} stop_words = ["stop"] @@ -71,8 +73,7 @@ def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_ assert generator._client is not None assert generator.streaming_callback == streaming_callback - def test_to_dict(self, mock_check_valid_model): - # Initialize the HuggingFaceRemoteGenerator object with valid parameters + def test_to_dict(self, mock_check_valid_model, mock_list_inference_deployed_models): generator = HuggingFaceTGIGenerator( model="mistralai/Mistral-7B-v0.1", token=Secret.from_env_var("ENV_VAR", strict=False), @@ -90,7 +91,7 @@ def test_to_dict(self, mock_check_valid_model): assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"} assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512} - def test_from_dict(self, mock_check_valid_model): + def test_from_dict(self, mock_check_valid_model, mock_list_inference_deployed_models): generator = HuggingFaceTGIGenerator( model="mistralai/Mistral-7B-v0.1", generation_kwargs={"n": 5}, @@ -110,14 +111,12 @@ def test_initialize_with_invalid_url(self, mock_check_valid_model): with pytest.raises(ValueError): HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", url="invalid_url") - def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model): - # When custom TGI endpoint is used via URL, model must be provided and valid HuggingFace Hub model id - mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") - with pytest.raises(RepositoryNotFoundError): - HuggingFaceTGIGenerator(model="invalid_model_id", url="https://some_chat_model.com") + def test_initialize_without_model_or_url(self): + with pytest.raises(ValueError): + HuggingFaceTGIGenerator(model=None, url=None) def test_generate_text_response_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models + self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models ): model = "mistralai/Mistral-7B-v0.1" @@ -131,7 +130,6 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters( stop_words=stop_words, streaming_callback=streaming_callback, ) - generator.warm_up() prompt = "Hello, how are you?" response = generator.run(prompt) @@ -151,7 +149,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters( assert [isinstance(reply, str) for reply in response["replies"]] def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models + self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models ): model = "mistralai/Mistral-7B-v0.1" generation_kwargs = {"n": 3} @@ -164,7 +162,6 @@ def test_generate_multiple_text_responses_with_valid_prompt_and_generation_param stop_words=stop_words, streaming_callback=streaming_callback, ) - generator.warm_up() prompt = "Hello, how are you?" response = generator.run(prompt) @@ -202,10 +199,9 @@ def test_initialize_with_invalid_model(self, mock_check_valid_model): ) def test_generate_text_with_stop_words( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models + self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models ): - generator = HuggingFaceTGIGenerator() - generator.warm_up() + generator = HuggingFaceTGIGenerator("HuggingFaceH4/zephyr-7b-alpha") # Generate text response with stop words response = generator.run("How are you?", generation_kwargs={"stop_words": ["stop", "words"]}) @@ -227,10 +223,9 @@ def test_generate_text_with_stop_words( assert [isinstance(reply, dict) for reply in response["replies"]] def test_generate_text_with_custom_generation_parameters( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models + self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models ): - generator = HuggingFaceTGIGenerator() - generator.warm_up() + generator = HuggingFaceTGIGenerator("HuggingFaceH4/zephyr-7b-alpha") generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} response = generator.run("How are you?", generation_kwargs=generation_kwargs) @@ -253,7 +248,7 @@ def test_generate_text_with_custom_generation_parameters( assert [isinstance(reply, str) for reply in response["replies"]] def test_generate_text_with_streaming_callback( - self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation, mock_list_inference_deployed_models + self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models ): streaming_call_count = 0 @@ -264,8 +259,9 @@ def streaming_callback_fn(chunk: StreamingChunk): assert isinstance(chunk, StreamingChunk) # Create an instance of HuggingFaceRemoteGenerator - generator = HuggingFaceTGIGenerator(streaming_callback=streaming_callback_fn) - generator.warm_up() + generator = HuggingFaceTGIGenerator( + model="HuggingFaceH4/zephyr-7b-alpha", streaming_callback=streaming_callback_fn + ) # Create a fake streamed response # Don't remove self From 647fbcb21434175d32a622f32f95c2494106d1a0 Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Tue, 2 Apr 2024 18:15:54 +0200 Subject: [PATCH 4/9] non-breaking changes --- .../components/generators/hugging_face_tgi.py | 29 +++++++----- .../generators/test_hugging_face_tgi.py | 44 ++++++------------- 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py index a046082238..7215235f19 100644 --- a/haystack/components/generators/hugging_face_tgi.py +++ b/haystack/components/generators/hugging_face_tgi.py @@ -1,3 +1,4 @@ +import warnings from dataclasses import asdict from typing import Any, Callable, Dict, Iterable, List, Optional from urllib.parse import urlparse @@ -20,6 +21,10 @@ logger = logging.getLogger(__name__) +# TODO: remove the default model in Haystack 2.3.0, as explained in the deprecation warning +DEFAULT_MODEL = "mistralai/Mistral-7B-v0.1" + + @component class HuggingFaceTGIGenerator: """ @@ -101,8 +106,13 @@ def __init__( huggingface_hub_import.check() if not model and not url: - raise ValueError("You must provide either a model or a TGI endpoint URL.") - if model and url: + warnings.warn( + f"Neither `model` nor `url` is provided. The component will use the default model: {DEFAULT_MODEL}. " + "This behavior is deprecated and will be removed in Haystack 2.3.0.", + DeprecationWarning, + ) + model = DEFAULT_MODEL + elif model and url: logger.warning("Both `model` and `url` are provided. The `model` parameter will be ignored. ") if url: @@ -113,14 +123,6 @@ def __init__( if model and not url: check_valid_model(model, HFModelType.GENERATION, token) - # TODO: remove this check when the huggingface_hub bugfix release is out - # https://github.com/huggingface/huggingface_hub/issues/2135 - tgi_deployed_models = list_inference_deployed_models() - if model not in tgi_deployed_models: - raise ValueError( - f"The model {model} is not correctly supported by the free tier of the HF inference API. " - f"Valid models are: {tgi_deployed_models}" - ) # handle generation kwargs setup generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} @@ -136,6 +138,13 @@ def __init__( self.streaming_callback = streaming_callback self._client = InferenceClient(url or model, token=token.resolve_value() if token else None) + def warm_up(self) -> None: + warnings.warn( + "The `warm_up` method of `HuggingFaceTGIGenerator` does nothing and is momentarily maintained to ensure backward compatibility. " + "It is deprecated and will be removed in Haystack 2.3.0.", + DeprecationWarning, + ) + def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. diff --git a/test/components/generators/test_hugging_face_tgi.py b/test/components/generators/test_hugging_face_tgi.py index 2978eb161b..aebbd1d5d3 100644 --- a/test/components/generators/test_hugging_face_tgi.py +++ b/test/components/generators/test_hugging_face_tgi.py @@ -4,22 +4,11 @@ from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput from huggingface_hub.utils import RepositoryNotFoundError -from haystack.components.generators import HuggingFaceTGIGenerator +from haystack.components.generators.hugging_face_tgi import DEFAULT_MODEL, HuggingFaceTGIGenerator from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret -@pytest.fixture -def mock_list_inference_deployed_models(): - with patch( - "haystack.components.generators.hugging_face_tgi.list_inference_deployed_models", - MagicMock( - return_value=["HuggingFaceH4/zephyr-7b-alpha", "HuggingFaceH4/zephyr-7b-alpha", "mistralai/Mistral-7B-v0.1"] - ), - ) as mock: - yield mock - - @pytest.fixture def mock_check_valid_model(): with patch( @@ -47,9 +36,7 @@ def streaming_callback_handler(x): class TestHuggingFaceTGIGenerator: - def test_initialize_with_valid_model_and_generation_parameters( - self, mock_check_valid_model, mock_list_inference_deployed_models - ): + def test_initialize_with_valid_model_and_generation_parameters(self, mock_check_valid_model): model = "HuggingFaceH4/zephyr-7b-alpha" generation_kwargs = {"n": 1} stop_words = ["stop"] @@ -73,7 +60,7 @@ def test_initialize_with_valid_model_and_generation_parameters( assert generator._client is not None assert generator.streaming_callback == streaming_callback - def test_to_dict(self, mock_check_valid_model, mock_list_inference_deployed_models): + def test_to_dict(self, mock_check_valid_model): generator = HuggingFaceTGIGenerator( model="mistralai/Mistral-7B-v0.1", token=Secret.from_env_var("ENV_VAR", strict=False), @@ -91,7 +78,7 @@ def test_to_dict(self, mock_check_valid_model, mock_list_inference_deployed_mode assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"} assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"], "max_new_tokens": 512} - def test_from_dict(self, mock_check_valid_model, mock_list_inference_deployed_models): + def test_from_dict(self, mock_check_valid_model): generator = HuggingFaceTGIGenerator( model="mistralai/Mistral-7B-v0.1", generation_kwargs={"n": 5}, @@ -111,12 +98,13 @@ def test_initialize_with_invalid_url(self, mock_check_valid_model): with pytest.raises(ValueError): HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", url="invalid_url") - def test_initialize_without_model_or_url(self): - with pytest.raises(ValueError): - HuggingFaceTGIGenerator(model=None, url=None) + def test_initialize_without_model_or_url(self, mock_check_valid_model): + generator = HuggingFaceTGIGenerator(model=None, url=None) + + assert generator.model == DEFAULT_MODEL def test_generate_text_response_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models + self, mock_check_valid_model, mock_text_generation ): model = "mistralai/Mistral-7B-v0.1" @@ -149,7 +137,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters( assert [isinstance(reply, str) for reply in response["replies"]] def test_generate_multiple_text_responses_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models + self, mock_check_valid_model, mock_text_generation ): model = "mistralai/Mistral-7B-v0.1" generation_kwargs = {"n": 3} @@ -198,9 +186,7 @@ def test_initialize_with_invalid_model(self, mock_check_valid_model): streaming_callback=streaming_callback, ) - def test_generate_text_with_stop_words( - self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models - ): + def test_generate_text_with_stop_words(self, mock_check_valid_model, mock_text_generation): generator = HuggingFaceTGIGenerator("HuggingFaceH4/zephyr-7b-alpha") # Generate text response with stop words @@ -222,9 +208,7 @@ def test_generate_text_with_stop_words( assert len(response["meta"]) > 0 assert [isinstance(reply, dict) for reply in response["replies"]] - def test_generate_text_with_custom_generation_parameters( - self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models - ): + def test_generate_text_with_custom_generation_parameters(self, mock_check_valid_model, mock_text_generation): generator = HuggingFaceTGIGenerator("HuggingFaceH4/zephyr-7b-alpha") generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} @@ -247,9 +231,7 @@ def test_generate_text_with_custom_generation_parameters( assert len(response["meta"]) > 0 assert [isinstance(reply, str) for reply in response["replies"]] - def test_generate_text_with_streaming_callback( - self, mock_check_valid_model, mock_text_generation, mock_list_inference_deployed_models - ): + def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_text_generation): streaming_call_count = 0 # Define the streaming callback function From b8f0195009c27134d1a29fbef28d3850e40b1d9f Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Tue, 2 Apr 2024 18:28:21 +0200 Subject: [PATCH 5/9] rm import --- test/components/generators/test_hugging_face_tgi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/components/generators/test_hugging_face_tgi.py b/test/components/generators/test_hugging_face_tgi.py index aebbd1d5d3..6cfacfeb48 100644 --- a/test/components/generators/test_hugging_face_tgi.py +++ b/test/components/generators/test_hugging_face_tgi.py @@ -2,7 +2,6 @@ import pytest from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput -from huggingface_hub.utils import RepositoryNotFoundError from haystack.components.generators.hugging_face_tgi import DEFAULT_MODEL, HuggingFaceTGIGenerator from haystack.dataclasses import StreamingChunk From f1effa122b780e5dd94117181785aba1e0f01131 Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Tue, 2 Apr 2024 18:39:00 +0200 Subject: [PATCH 6/9] rm another import --- haystack/components/generators/hugging_face_tgi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py index 7215235f19..10f6770e36 100644 --- a/haystack/components/generators/hugging_face_tgi.py +++ b/haystack/components/generators/hugging_face_tgi.py @@ -7,7 +7,7 @@ from haystack.dataclasses import StreamingChunk from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable -from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models +from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import: from huggingface_hub import ( From 6115b18e0f33e90c7f5b1fe709de5ae682ce177e Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Tue, 2 Apr 2024 18:48:43 +0200 Subject: [PATCH 7/9] release note --- .../notes/tgi-refactoring-62885781f81e18d1.yaml | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 releasenotes/notes/tgi-refactoring-62885781f81e18d1.yaml diff --git a/releasenotes/notes/tgi-refactoring-62885781f81e18d1.yaml b/releasenotes/notes/tgi-refactoring-62885781f81e18d1.yaml new file mode 100644 index 0000000000..288036ea22 --- /dev/null +++ b/releasenotes/notes/tgi-refactoring-62885781f81e18d1.yaml @@ -0,0 +1,11 @@ +--- +enhancements: + - | + Improve the HuggingFaceTGIGenerator component. + - Lighter: the component only depends on the `huggingface_hub` library. + - If initialized with an appropriate `url` parameter, the component can run on a local network and does not require internet access. +deprecations: + - | + - The HuggingFaceTGIGenerator component requires specifying either a `url` or `model` parameter. + Starting from Haystack 2.3.0, the component will raise an error if neither parameter is provided. + - The `warm_up` method of the HuggingFaceTGIGenerator component is deprecated and will be removed in 2.3.0 release. From 6897afb8d0f2fbd2130ed8ba7ec0e418cc971965 Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Tue, 2 Apr 2024 18:59:10 +0200 Subject: [PATCH 8/9] improve condition --- haystack/components/generators/hugging_face_tgi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py index 10f6770e36..60cd3ceac3 100644 --- a/haystack/components/generators/hugging_face_tgi.py +++ b/haystack/components/generators/hugging_face_tgi.py @@ -120,8 +120,7 @@ def __init__( is_valid_url = all([r.scheme in ["http", "https"], r.netloc]) if not is_valid_url: raise ValueError(f"Invalid TGI endpoint URL provided: {url}") - - if model and not url: + elif model: check_valid_model(model, HFModelType.GENERATION, token) # handle generation kwargs setup From 04ace00f2d83b51684966b2e7f841b62c34eed65 Mon Sep 17 00:00:00 2001 From: anakin87 <stefanofiorucci@gmail.com> Date: Tue, 2 Apr 2024 22:02:36 +0200 Subject: [PATCH 9/9] improve docstrings --- .../components/generators/hugging_face_tgi.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/haystack/components/generators/hugging_face_tgi.py b/haystack/components/generators/hugging_face_tgi.py index 60cd3ceac3..14c841db63 100644 --- a/haystack/components/generators/hugging_face_tgi.py +++ b/haystack/components/generators/hugging_face_tgi.py @@ -31,24 +31,19 @@ class HuggingFaceTGIGenerator: Enables text generation using HuggingFace Hub hosted non-chat LLMs. This component is designed to seamlessly inference models deployed on the Text Generation Inference (TGI) backend. - You can use this component for LLMs hosted on Hugging Face inference endpoints, the rate-limited - Inference API tier. + You can use this component for LLMs hosted on Hugging Face inference endpoints. Key Features and Compatibility: - - Primary Compatibility: designed to work seamlessly with any non-based model deployed using the TGI + - Primary Compatibility: designed to work seamlessly with models deployed using the TGI framework. For more information on TGI, visit [text-generation-inference](https://github.com/huggingface/text-generation-inference) - - Hugging Face Inference Endpoints: Supports inference of TGI chat LLMs deployed on Hugging Face + - Hugging Face Inference Endpoints: Supports inference of LLMs deployed on Hugging Face inference endpoints. For more details, refer to [inference-endpoints](https://huggingface.co/inference-endpoints) - - Inference API Support: supports inference of TGI LLMs hosted on the rate-limited Inference + - Inference API Support: supports inference of LLMs hosted on the rate-limited Inference API tier. Learn more about the Inference API at [inference-api](https://huggingface.co/inference-api). - Discover available chat models using the following command: `wget -qO- https://api-inference.huggingface.co/framework/text-generation-inference | grep chat` - and simply use the model ID as the model parameter for this component. You'll also need to provide a valid - Hugging Face API token as the token parameter. + In this case, you need to provide a valid Hugging Face token. - - Custom TGI Endpoints: supports inference of TGI chat LLMs deployed on custom TGI endpoints. Anyone can - deploy their own TGI endpoint using the TGI framework. For more details, refer to [inference-endpoints](https://huggingface.co/inference-endpoints) Input and Output Format: - String Format: This component uses the str format for structuring both input and output, @@ -64,7 +59,8 @@ class HuggingFaceTGIGenerator: ``` Or for LLMs hosted on paid https://huggingface.co/inference-endpoints endpoint, and/or your own custom TGI endpoint. - In these two cases, you'll need to provide the URL of the endpoint as well as a valid token: + In these two cases, you'll need to provide the URL of the endpoint. + For Inference Endpoints, you also need to provide a valid Hugging Face token. ```python from haystack.components.generators import HuggingFaceTGIGenerator