feat: Add streaming to HuggingFaceLocalGenerator (#7377)
* Initial streaming impl

* Add unit tests

* Add release note
vblagoje authored and silvanocerza committed Apr 8, 2024
1 parent fb19ae6 commit d1090d9
Showing 3 changed files with 68 additions and 5 deletions.
38 changes: 35 additions & 3 deletions haystack/components/generators/hugging_face_local.py
@@ -1,9 +1,16 @@
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional
 
 from haystack import component, default_from_dict, default_to_dict, logging
+from haystack.dataclasses import StreamingChunk
 from haystack.lazy_imports import LazyImport
-from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
-from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
+from haystack.utils import (
+    ComponentDevice,
+    Secret,
+    deserialize_callable,
+    deserialize_secrets_inplace,
+    serialize_callable,
+)
+from haystack.utils.hf import HFTokenStreamingHandler, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
 
 logger = logging.getLogger(__name__)

@@ -48,6 +55,7 @@ def __init__(
         generation_kwargs: Optional[Dict[str, Any]] = None,
         huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
         Creates an instance of a HuggingFaceLocalGenerator.
@@ -81,6 +89,7 @@ def __init__(
             If you provide this parameter, you should not specify the `stopping_criteria` in `generation_kwargs`.
             For some chat models, the output includes both the new text and the original prompt.
             In these cases, it's important to make sure your prompt has no stop words.
+        :param streaming_callback: An optional callable for handling streaming responses.
         """
         transformers_import.check()
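For reference, a streaming callback is just a callable that accepts a `StreamingChunk`. A minimal sketch (assuming `StreamingChunk` from `haystack.dataclasses` exposes the decoded text as `content`):

```python
from haystack.dataclasses import StreamingChunk


def print_streaming_chunk(chunk: StreamingChunk) -> None:
    # print each decoded piece of text as soon as it arrives, without a newline
    print(chunk.content, flush=True, end="")
```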

@@ -129,6 +138,7 @@ def __init__(
         self.stop_words = stop_words
         self.pipeline = None
         self.stopping_criteria_list = None
+        self.streaming_callback = streaming_callback
 
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -158,10 +168,12 @@ def to_dict(self) -> Dict[str, Any]:
         :returns:
             Dictionary with serialized data.
         """
+        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
         serialization_dict = default_to_dict(
             self,
             huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
             generation_kwargs=self.generation_kwargs,
+            streaming_callback=callback_name,
             stop_words=self.stop_words,
             token=self.token.to_dict() if self.token else None,
         )
@@ -184,6 +196,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalGenerator":
         """
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
+        init_params = data.get("init_parameters", {})
+        serialized_callback_handler = init_params.get("streaming_callback")
+        if serialized_callback_handler:
+            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
+
         return default_from_dict(cls, data)
 
     @component.output_types(replies=List[str])
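`serialize_callable` records the callback as a dotted import path and `deserialize_callable` re-imports it on `from_dict`, so only module-level functions round-trip (lambdas and local closures cannot be resolved by path). A round-trip sketch, reusing the illustrative `print_streaming_chunk` from above and assuming it lives in an importable module:

```python
generator = HuggingFaceLocalGenerator(
    model="google/flan-t5-base",
    task="text2text-generation",
    streaming_callback=print_streaming_chunk,
)

data = generator.to_dict()
# the callback is serialized to a string such as "my_module.print_streaming_chunk"
assert isinstance(data["init_parameters"]["streaming_callback"], str)

restored = HuggingFaceLocalGenerator.from_dict(data)
assert restored.streaming_callback is print_streaming_chunk
```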
@@ -209,6 +226,21 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         # merge generation kwargs from init method with those from run method
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
+        if self.streaming_callback:
+            num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
+            if num_responses > 1:
+                logger.warning(
+                    "Streaming is enabled, but the number of responses is set to %d. "
+                    "Streaming is only supported for single response generation. "
+                    "Setting the number of responses to 1.",
+                    num_responses,
+                )
+                updated_generation_kwargs["num_return_sequences"] = 1
+            # the `streamer` parameter hooks into HF streaming; HFTokenStreamingHandler adapts it to our streaming callback
+            updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
+                self.pipeline.tokenizer, self.streaming_callback, self.stop_words
+            )
+
         output = self.pipeline(prompt, stopping_criteria=self.stopping_criteria_list, **updated_generation_kwargs)
         replies = [o["generated_text"] for o in output if "generated_text" in o]
4 changes: 4 additions & 0 deletions (new release note file)
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Adds 'streaming_callback' parameter to 'HuggingFaceLocalGenerator', allowing users to handle streaming responses.
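Putting it together, a usage sketch for the new parameter (model, prompt, and callback are illustrative):

```python
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.dataclasses import StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, flush=True, end="")


generator = HuggingFaceLocalGenerator(
    model="google/flan-t5-base",
    task="text2text-generation",
    generation_kwargs={"max_new_tokens": 100},
    streaming_callback=print_chunk,
)
generator.warm_up()  # loads the local HF pipeline

result = generator.run("What is the capital of Italy?")
# tokens were printed incrementally by print_chunk; the full reply is also returned
print(result["replies"])
```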
31 changes: 29 additions & 2 deletions test/components/generators/test_hugging_face_local_generator.py
@@ -4,10 +4,11 @@
 import pytest
 import torch
 from transformers import PreTrainedTokenizerFast
-from haystack.utils.auth import Secret
 
 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
 from haystack.utils import ComponentDevice
+from haystack.utils.auth import Secret
+from haystack.utils.hf import HFTokenStreamingHandler
 
 
 class TestHuggingFaceLocalGenerator:
@@ -153,7 +154,8 @@ def test_to_dict_default(self, model_info_mock):
                     "task": "text2text-generation",
                     "device": ComponentDevice.resolve_device(None).to_hf(),
                 },
-                "generation_kwargs": {},
+                "generation_kwargs": {"max_new_tokens": 512},
+                "streaming_callback": None,
                 "stop_words": None,
             },
         }
@@ -194,6 +196,7 @@ def test_to_dict_with_parameters(self):
                     },
                 },
                 "generation_kwargs": {"max_new_tokens": 100, "return_full_text": False},
+                "streaming_callback": None,
                 "stop_words": ["coca", "cola"],
             },
         }
@@ -238,6 +241,7 @@ def test_to_dict_with_quantization_config(self):
                    },
                },
                "generation_kwargs": {"max_new_tokens": 100, "return_full_text": False},
+                "streaming_callback": None,
                "stop_words": ["coca", "cola"],
            },
        }
@@ -350,6 +354,29 @@ def test_run_with_generation_kwargs(self):
             "irrelevant", max_new_tokens=200, temperature=0.5, stopping_criteria=None
         )
 
+    def test_run_with_streaming(self):
+        def streaming_callback_handler(x):
+            return x
+
+        generator = HuggingFaceLocalGenerator(
+            model="google/flan-t5-base", task="text2text-generation", streaming_callback=streaming_callback_handler
+        )
+
+        # create the pipeline object (simulating the warm_up)
+        generator.pipeline = Mock(return_value=[{"generated_text": "Rome"}])
+
+        generator.run(prompt="irrelevant")
+
+        # when we use streaming, the pipeline should be called with the `streamer` argument
+        # being an instance of our adapter class HFTokenStreamingHandler
+        assert isinstance(generator.pipeline.call_args.kwargs["streamer"], HFTokenStreamingHandler)
+        streamer = generator.pipeline.call_args.kwargs["streamer"]
+
+        # check that the streaming callback is set
+        assert streamer.token_handler == streaming_callback_handler
+        # the tokenizer should be set, here it is a mock
+        assert streamer.tokenizer
+
     def test_run_fails_without_warm_up(self):
         generator = HuggingFaceLocalGenerator(
             model="google/flan-t5-base", task="text2text-generation", generation_kwargs={"max_new_tokens": 100}
