From 7f4ce0d644429df8e82d50bf98dd0f84fad8b29a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavel=20Ti=C5=A1novsk=C3=BD?=
Date: Mon, 27 Jan 2025 11:22:18 +0100
Subject: [PATCH] Pass streaming parameter to LLM loader (#317)

* Pass streaming parameter to LLM loader

Signed-off-by: Pavel Tisnovsky

* Updated tests accordingly

Signed-off-by: Pavel Tisnovsky

* Merge changes into unit test for DocsSummarizer

---------

Signed-off-by: Pavel Tisnovsky
---
 ols/app/endpoints/health.py                   |  2 +-
 ols/app/endpoints/ols.py                      |  1 +
 ols/src/llms/llm_loader.py                    |  6 ++-
 ols/src/query_helpers/docs_summarizer.py      |  2 +-
 ols/src/query_helpers/query_helper.py         |  4 +-
 ols/src/query_helpers/question_validator.py   |  4 +-
 tests/mock_classes/mock_llm_loader.py         |  2 +-
 .../query_helpers/test_docs_summarizer.py     | 40 ++++++++++++++-----
 tests/unit/query_helpers/test_query_helper.py | 10 +++++
 .../query_helpers/test_question_validator.py  | 20 ++++++++--
 10 files changed, 72 insertions(+), 19 deletions(-)

diff --git a/ols/app/endpoints/health.py b/ols/app/endpoints/health.py
index 4a0cfd46..9dd51d59 100644
--- a/ols/app/endpoints/health.py
+++ b/ols/app/endpoints/health.py
@@ -45,7 +45,7 @@ def llm_is_ready() -> bool:
     llm_is_ready_persistent_state = False
     try:
         bare_llm = load_llm(
-            config.ols_config.default_provider, config.ols_config.default_model
+            config.ols_config.default_provider, config.ols_config.default_model, False
         )
         response = bare_llm.invoke(input="Hello there!")
         # BAM and Watsonx replies as str and not as `AIMessage`
diff --git a/ols/app/endpoints/ols.py b/ols/app/endpoints/ols.py
index 33fb6aa9..3a4f9793 100644
--- a/ols/app/endpoints/ols.py
+++ b/ols/app/endpoints/ols.py
@@ -376,6 +376,7 @@ def generate_response(
         provider=llm_request.provider,
         model=llm_request.model,
         system_prompt=llm_request.system_prompt,
+        streaming=streaming,
     )
     history = CacheEntry.cache_entries_to_history(previous_input)
     if streaming:
diff --git a/ols/src/llms/llm_loader.py b/ols/src/llms/llm_loader.py
index ca3ea09d..0f74b08a 100644
--- a/ols/src/llms/llm_loader.py
+++ b/ols/src/llms/llm_loader.py
@@ -53,7 +53,10 @@ def resolve_provider_config(
 
 
 def load_llm(
-    provider: str, model: str, generic_llm_params: Optional[dict] = None
+    provider: str,
+    model: str,
+    generic_llm_params: Optional[dict] = None,
+    streaming: Optional[bool] = None,
 ) -> LLM:
     """Load LLM according to input provider and model.
 
@@ -61,6 +64,7 @@
         provider: The provider name.
         model: The model name.
         generic_llm_params: The optional parameters that will be converted into LLM-specific ones.
+        streaming: The optional parameter that enables streaming on the LLM side if set to True.
 
     Raises:
         LLMConfigurationError: If the whole provider configuration is missing.
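
For illustration only, here is a minimal sketch of how a caller can use the extended load_llm() signature shown above. The provider and model names are hypothetical placeholders, and the sketch assumes the loader simply forwards the flag to the provider-specific LLM it constructs:

    from ols.src.llms.llm_loader import load_llm

    # flag omitted: streaming stays None and the loader behaves as before
    bare_llm = load_llm("my-provider", "my-model")

    # explicit opt-in: ask the loader for a streaming-enabled LLM instance
    streaming_llm = load_llm("my-provider", "my-model", streaming=True)
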
diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py
index 71eb19cf..8049db93 100644
--- a/ols/src/query_helpers/docs_summarizer.py
+++ b/ols/src/query_helpers/docs_summarizer.py
@@ -42,7 +42,7 @@ def _prepare_llm(self) -> None:
             GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: self.model_config.parameters.max_tokens_for_response  # noqa: E501
         }
         self.bare_llm = self.llm_loader(
-            self.provider, self.model, self.generic_llm_params
+            self.provider, self.model, self.generic_llm_params, self.streaming
         )
 
     def _get_system_prompt(self) -> None:
diff --git a/ols/src/query_helpers/query_helper.py b/ols/src/query_helpers/query_helper.py
index 747df6de..6fe5a53b 100644
--- a/ols/src/query_helpers/query_helper.py
+++ b/ols/src/query_helpers/query_helper.py
@@ -21,8 +21,9 @@ def __init__(
         provider: Optional[str] = None,
         model: Optional[str] = None,
         generic_llm_params: Optional[dict] = None,
-        llm_loader: Optional[Callable[[str, str, dict], LLM]] = None,
+        llm_loader: Optional[Callable[[str, str, dict, bool], LLM]] = None,
         system_prompt: Optional[str] = None,
+        streaming: Optional[bool] = None,
     ) -> None:
         """Initialize query helper."""
         # NOTE: As signature of this method is evaluated before the config,
@@ -32,6 +33,7 @@
         self.model = model or config.ols_config.default_model
         self.generic_llm_params = generic_llm_params or {}
         self.llm_loader = llm_loader or load_llm
+        self.streaming = streaming or False
         self._system_prompt = (
             (config.dev_config.enable_system_prompt_override and system_prompt)
diff --git a/ols/src/query_helpers/question_validator.py b/ols/src/query_helpers/question_validator.py
index 478095ba..a5afc9e0 100644
--- a/ols/src/query_helpers/question_validator.py
+++ b/ols/src/query_helpers/question_validator.py
@@ -57,7 +57,9 @@ def validate_question(
             prompts.QUESTION_VALIDATOR_PROMPT_TEMPLATE
         )
 
-        bare_llm = self.llm_loader(self.provider, self.model, self.generic_llm_params)
+        bare_llm = self.llm_loader(
+            self.provider, self.model, self.generic_llm_params, self.streaming
+        )
 
         # Tokens-check: We trigger the computation of the token count
         # without care about the return value. This is to ensure that
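
The query_helper.py change above normalizes a missing flag with "streaming or False", so helpers such as DocsSummarizer and QuestionValidator always store a plain boolean. A minimal usage sketch of that behaviour, mirroring the unit tests added below and assuming the global config has already been loaded:

    from ols.src.query_helpers.query_helper import QueryHelper

    qh = QueryHelper()                # flag omitted -> normalized to False
    assert qh.streaming is False

    qh = QueryHelper(streaming=True)  # explicit opt-in is kept as True
    assert qh.streaming is True
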
diff --git a/tests/mock_classes/mock_llm_loader.py b/tests/mock_classes/mock_llm_loader.py
index 5b6f8b9e..a2a23f03 100644
--- a/tests/mock_classes/mock_llm_loader.py
+++ b/tests/mock_classes/mock_llm_loader.py
@@ -27,7 +27,7 @@ def loader(*args, **kwargs):
         # if expected params are provided, check if (mocked) LLM loader
         # was called with expected parameters
         if expected_params is not None:
-            assert expected_params == args
+            assert expected_params == args, expected_params
         return MockLLMLoader(llm)
 
     return loader
diff --git a/tests/unit/query_helpers/test_docs_summarizer.py b/tests/unit/query_helpers/test_docs_summarizer.py
index 91ec3e8d..c82d6347 100644
--- a/tests/unit/query_helpers/test_docs_summarizer.py
+++ b/tests/unit/query_helpers/test_docs_summarizer.py
@@ -7,15 +7,25 @@
 from langchain_core.messages import AIMessage, HumanMessage
 
 from ols import config
-from ols.app.models.config import LoggingConfig
-from ols.src.query_helpers.docs_summarizer import DocsSummarizer, QueryHelper
-from ols.utils import suid
-from ols.utils.logging_configurator import configure_logging
-from tests import constants
-from tests.mock_classes.mock_langchain_interface import mock_langchain_interface
-from tests.mock_classes.mock_llama_index import MockLlamaIndex
-from tests.mock_classes.mock_llm_chain import mock_llm_chain
-from tests.mock_classes.mock_llm_loader import mock_llm_loader
+
+# needs to be setup there before is_user_authorized is imported
+config.ols_config.authentication_config.module = "k8s"
+
+
+from ols.app.models.config import LoggingConfig  # noqa:E402
+from ols.src.query_helpers.docs_summarizer import (  # noqa:E402
+    DocsSummarizer,
+    QueryHelper,
+)
+from ols.utils import suid  # noqa:E402
+from ols.utils.logging_configurator import configure_logging  # noqa:E402
+from tests import constants  # noqa:E402
+from tests.mock_classes.mock_langchain_interface import (  # noqa:E402
+    mock_langchain_interface,
+)
+from tests.mock_classes.mock_llama_index import MockLlamaIndex  # noqa:E402
+from tests.mock_classes.mock_llm_chain import mock_llm_chain  # noqa:E402
+from tests.mock_classes.mock_llm_loader import mock_llm_loader  # noqa:E402
 
 
 conversation_id = suid.get_suid()
@@ -51,6 +61,18 @@ def test_if_system_prompt_was_updated():
     assert summarizer.system_prompt == expected_prompt
 
 
+def test_docs_summarizer_streaming_parameter():
+    """Test if optional streaming parameter is stored."""
+    summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None))
+    assert summarizer.streaming is False
+
+    summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None), streaming=False)
+    assert summarizer.streaming is False
+
+    summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None), streaming=True)
+    assert summarizer.streaming is True
+
+
 @patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4)
 @patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 1)
 @patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None))
diff --git a/tests/unit/query_helpers/test_query_helper.py b/tests/unit/query_helpers/test_query_helper.py
index 35ddb314..19431152 100644
--- a/tests/unit/query_helpers/test_query_helper.py
+++ b/tests/unit/query_helpers/test_query_helper.py
@@ -15,6 +15,7 @@ def test_defaults_used():
     assert qh.model == config.ols_config.default_model
     assert qh.llm_loader is load_llm
     assert qh.generic_llm_params == {}
+    assert qh.streaming is False
 
 
 def test_inputs_are_used():
@@ -25,3 +26,12 @@ def test_inputs_are_used():
     assert qh.provider == test_provider
     assert qh.model == test_model
+
+
+def test_streaming_parameter():
+    """Test that the optional streaming parameter is stored."""
+    qh = QueryHelper(streaming=False)
+    assert qh.streaming is False
+
+    qh = QueryHelper(streaming=True)
+    assert qh.streaming is True
diff --git a/tests/unit/query_helpers/test_question_validator.py b/tests/unit/query_helpers/test_question_validator.py
index 4d338814..9939ce58 100644
--- a/tests/unit/query_helpers/test_question_validator.py
+++ b/tests/unit/query_helpers/test_question_validator.py
@@ -6,9 +6,16 @@
 from ols import config
 from ols.constants import GenericLLMParameters
-from ols.src.query_helpers.question_validator import QueryHelper, QuestionValidator
-from tests.mock_classes.mock_llm_chain import mock_llm_chain
-from tests.mock_classes.mock_llm_loader import mock_llm_loader
+
+# needs to be setup there before is_user_authorized is imported
+config.ols_config.authentication_config.module = "k8s"
+
+from ols.src.query_helpers.question_validator import (  # noqa: E402
+    QueryHelper,
+    QuestionValidator,
+)
+from tests.mock_classes.mock_llm_chain import mock_llm_chain  # noqa: E402
+from tests.mock_classes.mock_llm_loader import mock_llm_loader  # noqa: E402
 
 
 @pytest.fixture
@@ -67,7 +74,12 @@ def test_validate_question_llm_loader():
     # be performed
     llm_loader = mock_llm_loader(
         None,
-        expected_params=("p1", "m1", {GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: 4}),
+        expected_params=(
+            "p1",
+            "m1",
+            {GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: 4},
+            False,
+        ),
     )
 
     # check that LLM loader was called with expected parameters
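
The expected_params tuple above now ends with False because validate_question() passes self.streaming as the fourth positional argument to the loader. A simplified, hypothetical sketch of such a checking loader (the real mock_llm_loader in tests/mock_classes/mock_llm_loader.py additionally wraps and returns a mock LLM object):

    def make_checking_loader(expected_params):
        # fake llm_loader: assert the exact positional tuple
        # (provider, model, generic_llm_params, streaming)
        def loader(*args, **kwargs):
            assert expected_params == args, expected_params
            return None  # a real test double would return a mock LLM here
        return loader

    loader = make_checking_loader(("p1", "m1", {}, False))
    loader("p1", "m1", {}, False)  # matches; any other call raises AssertionError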