diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py
index 195226acbf..60058e73c4 100644
--- a/haystack/components/generators/hugging_face_local.py
+++ b/haystack/components/generators/hugging_face_local.py
@@ -1,9 +1,16 @@
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional
 
 from haystack import component, default_from_dict, default_to_dict, logging
+from haystack.dataclasses import StreamingChunk
 from haystack.lazy_imports import LazyImport
-from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
-from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
+from haystack.utils import (
+    ComponentDevice,
+    Secret,
+    deserialize_callable,
+    deserialize_secrets_inplace,
+    serialize_callable,
+)
+from haystack.utils.hf import HFTokenStreamingHandler, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
 
 logger = logging.getLogger(__name__)
 
@@ -48,6 +55,7 @@ def __init__(
         generation_kwargs: Optional[Dict[str, Any]] = None,
         huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
         Creates an instance of a HuggingFaceLocalGenerator.
@@ -81,6 +89,7 @@ def __init__(
             If you provide this parameter, you should not specify the `stopping_criteria` in `generation_kwargs`.
             For some chat models, the output includes both the new text and the original prompt.
             In these cases, it's important to make sure your prompt has no stop words.
+        :param streaming_callback: An optional callable for handling streaming responses.
         """
         transformers_import.check()
 
@@ -129,6 +138,7 @@ def __init__(
         self.stop_words = stop_words
         self.pipeline = None
         self.stopping_criteria_list = None
+        self.streaming_callback = streaming_callback
 
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -158,10 +168,12 @@ def to_dict(self) -> Dict[str, Any]:
         :returns:
             Dictionary with serialized data.
         """
+        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
         serialization_dict = default_to_dict(
             self,
             huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
             generation_kwargs=self.generation_kwargs,
+            streaming_callback=callback_name,
             stop_words=self.stop_words,
             token=self.token.to_dict() if self.token else None,
         )
@@ -184,6 +196,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalGenerator":
         """
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
+        init_params = data.get("init_parameters", {})
+        serialized_callback_handler = init_params.get("streaming_callback")
+        if serialized_callback_handler:
+            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
+
         return default_from_dict(cls, data)
 
     @component.output_types(replies=List[str])
@@ -209,6 +226,21 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         # merge generation kwargs from init method with those from run method
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
+        if self.streaming_callback:
+            num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
+            if num_responses > 1:
+                logger.warning(
+                    "Streaming is enabled, but the number of responses is set to %d. "
+                    "Streaming is only supported for single response generation. "
+                    "Setting the number of responses to 1.",
+                    num_responses,
+                )
+                updated_generation_kwargs["num_return_sequences"] = 1
+            # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
+            updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
+                self.pipeline.tokenizer, self.streaming_callback, self.stop_words
+            )
+
         output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
         replies = [o["generated_text"] for o in output if "generated_text" in o]
 
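For context, a minimal usage sketch of the feature introduced above (not part of the patch; it assumes a local environment with haystack-ai and transformers installed, and that StreamingChunk exposes the generated text via its `content` attribute):

    from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
    from haystack.dataclasses import StreamingChunk

    def print_chunk(chunk: StreamingChunk) -> None:
        # invoked once per newly generated token; echo the text as it arrives
        print(chunk.content, end="", flush=True)

    generator = HuggingFaceLocalGenerator(
        model="google/flan-t5-base",
        task="text2text-generation",
        generation_kwargs={"max_new_tokens": 100},
        streaming_callback=print_chunk,
    )
    generator.warm_up()  # loads the underlying transformers pipeline
    result = generator.run(prompt="What is the capital of Italy?")
    print(result["replies"])
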
" + "Streaming is only supported for single response generation. " + "Setting the number of responses to 1.", + num_responses, + ) + updated_generation_kwargs["num_return_sequences"] = 1 + # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming + updated_generation_kwargs["streamer"] = HFTokenStreamingHandler( + self.pipeline.tokenizer, self.streaming_callback, self.stop_words + ) + output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs) replies = [o["generated_text"] for o in output if "generated_text" in o] diff --git a/releasenotes/notes/hugging-face-local-generator-streaming-callback-38a77d37199f9672.yaml b/releasenotes/notes/hugging-face-local-generator-streaming-callback-38a77d37199f9672.yaml new file mode 100644 index 0000000000..47b1f6776e --- /dev/null +++ b/releasenotes/notes/hugging-face-local-generator-streaming-callback-38a77d37199f9672.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Adds 'streaming_callback' parameter to 'HuggingFaceLocalGenerator', allowing users to handle streaming responses. diff --git a/test/components/generators/test_hugging_face_local_generator.py b/test/components/generators/test_hugging_face_local_generator.py index 714decc56a..a95e61ff0d 100644 --- a/test/components/generators/test_hugging_face_local_generator.py +++ b/test/components/generators/test_hugging_face_local_generator.py @@ -4,10 +4,11 @@ import pytest import torch from transformers import PreTrainedTokenizerFast -from haystack.utils.auth import Secret from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria from haystack.utils import ComponentDevice +from haystack.utils.auth import Secret +from haystack.utils.hf import HFTokenStreamingHandler class TestHuggingFaceLocalGenerator: @@ -153,7 +154,8 @@ def test_to_dict_default(self, model_info_mock): "task": "text2text-generation", "device": ComponentDevice.resolve_device(None).to_hf(), }, - "generation_kwargs": {}, + "generation_kwargs": {"max_new_tokens": 512}, + "streaming_callback": None, "stop_words": None, }, } @@ -194,6 +196,7 @@ def test_to_dict_with_parameters(self): }, }, "generation_kwargs": {"max_new_tokens": 100, "return_full_text": False}, + "streaming_callback": None, "stop_words": ["coca", "cola"], }, } @@ -238,6 +241,7 @@ def test_to_dict_with_quantization_config(self): }, }, "generation_kwargs": {"max_new_tokens": 100, "return_full_text": False}, + "streaming_callback": None, "stop_words": ["coca", "cola"], }, } @@ -350,6 +354,29 @@ def test_run_with_generation_kwargs(self): "irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None ) + def test_run_with_streaming(self): + def streaming_callback_handler(x): + return x + + generator = HuggingFaceLocalGenerator( + model="google/flan-t5-base", task="text2text-generation", streaming_callback=streaming_callback_handler + ) + + # create the pipeline object (simulating the warm_up) + generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}]) + + generator.run(prompt="irrelevant") + + # when we use streaming, the pipeline should be called with the `streamer` argument being an instance of + # ouf our adapter class HFTokenStreamingHandler + assert isinstance(generator.pipeline.call_args.kwargs["streamer"], HFTokenStreamingHandler) + streamer = generator.pipeline.call_args.kwargs["streamer"] + + # check that the streaming callback is set + assert streamer.token_handler == streaming_callback_handler + # the tokenizer 
+        assert streamer.tokenizer
+
     def test_run_fails_without_warm_up(self):
         generator = HuggingFaceLocalGenerator(
             model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
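Also for context, a sketch of how the new callback round-trips through the to_dict/from_dict changes above (not part of the patch; `print_chunk` is the illustrative callback from the earlier sketch and must be importable, e.g. defined at module level, for deserialize_callable to restore it):

    from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator

    generator = HuggingFaceLocalGenerator(
        model="google/flan-t5-base",
        task="text2text-generation",
        streaming_callback=print_chunk,
    )
    data = generator.to_dict()
    # the callback is stored under init_parameters as its importable dotted path
    print(data["init_parameters"]["streaming_callback"])

    restored = HuggingFaceLocalGenerator.from_dict(data)
    assert restored.streaming_callback is print_chunk  # the original function is recovered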