From 21dd03d3e79774de89d772d2ce37dec11e2cd5e4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 17 Jan 2025 09:58:45 +0100 Subject: [PATCH] feat: Add completion start time timestamp to relevant generators (#8728) * OpenAIChatGenerator - add completion_start_time * HuggingFaceAPIChatGenerator - add completion_start_time * Add tests * Add reno note * Relax condition for cached responses * Add completion_start_time timestamping to non-chat generators * Update haystack/components/generators/chat/hugging_face_api.py Co-authored-by: Stefano Fiorucci * PR feedback --------- Co-authored-by: Stefano Fiorucci --- .../generators/chat/hugging_face_api.py | 6 +++++ haystack/components/generators/chat/openai.py | 3 +++ .../components/generators/hugging_face_api.py | 9 ++++++++ haystack/components/generators/openai.py | 14 +++++++---- ...completion-timestamp-c0ad3b8698a2d575.yaml | 4 ++++ .../generators/chat/test_hugging_face_api.py | 11 ++++++--- .../components/generators/chat/test_openai.py | 4 ++++ .../generators/test_hugging_face_api.py | 23 +++++++++++++++++++ test/components/generators/test_openai.py | 4 ++++ 9 files changed, 71 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/add-streaming-completion-timestamp-c0ad3b8698a2d575.yaml diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 50a730a01f..1264272fca 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from datetime import datetime from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging @@ -259,6 +260,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict ) generated_text = "" + first_chunk_time = None for chunk in api_output: # n is unused, so the API always returns only one choice @@ -276,6 +278,9 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict if finish_reason: meta["finish_reason"] = finish_reason + if first_chunk_time is None: + first_chunk_time = datetime.now().isoformat() + stream_chunk = StreamingChunk(text, meta) self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) @@ -285,6 +290,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict "finish_reason": finish_reason, "index": 0, "usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming + "completion_start_time": first_chunk_time, } ) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 0b699e3bc1..b30de1b43d 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -4,6 +4,7 @@ import json import os +from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union from openai import OpenAI, Stream @@ -381,6 +382,7 @@ def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[Str "model": chunk.model, "index": 0, "finish_reason": chunk.choices[0].finish_reason, + "completion_start_time": chunks[0].meta.get("received_at"), # first chunk received "usage": {}, # we don't have usage data for streaming responses } @@ -444,6 +446,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio "index": choice.index, "tool_calls": choice.delta.tool_calls, "finish_reason": choice.finish_reason, + "received_at": datetime.now().isoformat(), } ) return chunk_message diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index a44ad94575..0a977f1603 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import asdict +from datetime import datetime from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging @@ -217,18 +218,26 @@ def _stream_and_build_response( self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None] ): chunks: List[StreamingChunk] = [] + first_chunk_time = None + for chunk in hf_output: token: TextGenerationOutputToken = chunk.token if token.special: continue + chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})} + if first_chunk_time is None: + first_chunk_time = datetime.now().isoformat() + stream_chunk = StreamingChunk(token.text, chunk_metadata) chunks.append(stream_chunk) streaming_callback(stream_chunk) + metadata = { "finish_reason": chunks[-1].meta.get("finish_reason", None), "model": self._client.model, "usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)}, + "completion_start_time": first_chunk_time, } return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]} diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index d2f07f9d85..3a87b8c068 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os +from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union from openai import OpenAI, Stream @@ -255,7 +256,7 @@ def _create_message_from_chunks( "model": completion_chunk.model, "index": 0, "finish_reason": finish_reason, - # Usage is available when streaming only if the user explicitly requests it + "completion_start_time": streamed_chunks[0].meta.get("received_at"), # first chunk received "usage": dict(completion_chunk.usage or {}), } ) @@ -296,12 +297,17 @@ def _build_chunk(chunk: Any) -> StreamingChunk: :returns: The StreamingChunk. """ - # function or tools calls are not going to happen in non-chat generation - # as users can not send ChatMessage with function or tools calls choice = chunk.choices[0] content = choice.delta.content or "" chunk_message = StreamingChunk(content) - chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason}) + chunk_message.meta.update( + { + "model": chunk.model, + "index": choice.index, + "finish_reason": choice.finish_reason, + "received_at": datetime.now().isoformat(), + } + ) return chunk_message @staticmethod diff --git a/releasenotes/notes/add-streaming-completion-timestamp-c0ad3b8698a2d575.yaml b/releasenotes/notes/add-streaming-completion-timestamp-c0ad3b8698a2d575.yaml new file mode 100644 index 0000000000..2718c6fdcd --- /dev/null +++ b/releasenotes/notes/add-streaming-completion-timestamp-c0ad3b8698a2d575.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added completion_start_time metadata to track time-to-first-token (TTFT) in streaming responses from Hugging Face API and OpenAI (Azure). diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index fa83b98db7..f9e306c46e 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from datetime import datetime import os from unittest.mock import MagicMock, Mock, patch @@ -503,9 +504,13 @@ def test_live_run_serverless_streaming(self): assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - assert "usage" in response["replies"][0].meta - assert "prompt_tokens" in response["replies"][0].meta["usage"] - assert "completion_tokens" in response["replies"][0].meta["usage"] + + response_meta = response["replies"][0].meta + assert "completion_start_time" in response_meta + assert datetime.fromisoformat(response_meta["completion_start_time"]) <= datetime.now() + assert "usage" in response_meta + assert "prompt_tokens" in response_meta["usage"] + assert "completion_tokens" in response_meta["usage"] @pytest.mark.integration @pytest.mark.skipif( diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index eb50d92739..63a920a8ec 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -546,6 +546,10 @@ def __call__(self, chunk: StreamingChunk) -> None: assert callback.counter > 1 assert "Paris" in callback.responses + # check that the completion_start_time is set and valid ISO format + assert "completion_start_time" in message.meta + assert datetime.fromisoformat(message.meta["completion_start_time"]) < datetime.now() + @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", diff --git a/test/components/generators/test_hugging_face_api.py b/test/components/generators/test_hugging_face_api.py index 0f4be2f9cb..965cf0cf81 100644 --- a/test/components/generators/test_hugging_face_api.py +++ b/test/components/generators/test_hugging_face_api.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os from unittest.mock import MagicMock, Mock, patch +from datetime import datetime import pytest from huggingface_hub import ( @@ -312,3 +313,25 @@ def test_run_serverless(self): assert isinstance(response["meta"], list) assert len(response["meta"]) > 0 assert [isinstance(meta, dict) for meta in response["meta"]] + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + def test_live_run_streaming_check_completion_start_time(self): + generator = HuggingFaceAPIGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + streaming_callback=streaming_callback_handler, + ) + + results = generator.run("What is the capital of France?") + + assert len(results["replies"]) == 1 + assert "Paris" in results["replies"][0] + + # Verify completion start time in final metadata + assert "completion_start_time" in results["meta"][0] + completion_start = datetime.fromisoformat(results["meta"][0]["completion_start_time"]) + assert completion_start <= datetime.now() diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index e1d865c95f..816412ee97 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from datetime import datetime import logging import os from typing import List @@ -286,6 +287,9 @@ def __call__(self, chunk: StreamingChunk) -> None: assert "gpt-4o-mini" in metadata["model"] assert metadata["finish_reason"] == "stop" + assert "completion_start_time" in metadata + assert datetime.fromisoformat(metadata["completion_start_time"]) <= datetime.now() + # unfortunately, the usage is not available for streaming calls # we keep the key in the metadata for compatibility assert "usage" in metadata and len(metadata["usage"]) == 0