From b9bc7eba9e56f21720af25ad5487fdf75afa531b Mon Sep 17 00:00:00 2001
From: Pavel Tisnovsky
Date: Sun, 26 Jan 2025 14:53:02 +0100
Subject: [PATCH] Updated tests accordingly

Signed-off-by: Pavel Tisnovsky
---
 ols/app/endpoints/health.py                   |  2 +-
 ols/src/query_helpers/query_helper.py         |  2 +-
 ols/src/query_helpers/question_validator.py   |  4 +++-
 tests/mock_classes/mock_llm_loader.py         |  2 +-
 tests/unit/query_helpers/test_query_helper.py | 10 ++++++++++
 .../query_helpers/test_question_validator.py  | 20 +++++++++++++++----
 6 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/ols/app/endpoints/health.py b/ols/app/endpoints/health.py
index 4a0cfd46..9dd51d59 100644
--- a/ols/app/endpoints/health.py
+++ b/ols/app/endpoints/health.py
@@ -45,7 +45,7 @@ def llm_is_ready() -> bool:
     llm_is_ready_persistent_state = False
     try:
         bare_llm = load_llm(
-            config.ols_config.default_provider, config.ols_config.default_model
+            config.ols_config.default_provider, config.ols_config.default_model, False
         )
         response = bare_llm.invoke(input="Hello there!")
         # BAM and Watsonx replies as str and not as `AIMessage`
diff --git a/ols/src/query_helpers/query_helper.py b/ols/src/query_helpers/query_helper.py
index 12f208f0..6fe5a53b 100644
--- a/ols/src/query_helpers/query_helper.py
+++ b/ols/src/query_helpers/query_helper.py
@@ -21,7 +21,7 @@ def __init__(
         provider: Optional[str] = None,
         model: Optional[str] = None,
         generic_llm_params: Optional[dict] = None,
-        llm_loader: Optional[Callable[[str, str, dict], LLM]] = None,
+        llm_loader: Optional[Callable[[str, str, dict, bool], LLM]] = None,
         system_prompt: Optional[str] = None,
         streaming: Optional[bool] = None,
     ) -> None:
diff --git a/ols/src/query_helpers/question_validator.py b/ols/src/query_helpers/question_validator.py
index 478095ba..a5afc9e0 100644
--- a/ols/src/query_helpers/question_validator.py
+++ b/ols/src/query_helpers/question_validator.py
@@ -57,7 +57,9 @@ def validate_question(
             prompts.QUESTION_VALIDATOR_PROMPT_TEMPLATE
         )

-        bare_llm = self.llm_loader(self.provider, self.model, self.generic_llm_params)
+        bare_llm = self.llm_loader(
+            self.provider, self.model, self.generic_llm_params, self.streaming
+        )

         # Tokens-check: We trigger the computation of the token count
         # without care about the return value. This is to ensure that
diff --git a/tests/mock_classes/mock_llm_loader.py b/tests/mock_classes/mock_llm_loader.py
index 5b6f8b9e..a2a23f03 100644
--- a/tests/mock_classes/mock_llm_loader.py
+++ b/tests/mock_classes/mock_llm_loader.py
@@ -27,7 +27,7 @@ def loader(*args, **kwargs):
         # if expected params are provided, check if (mocked) LLM loader
         # was called with expected parameters
         if expected_params is not None:
-            assert expected_params == args
+            assert expected_params == args, expected_params
         return MockLLMLoader(llm)

     return loader
diff --git a/tests/unit/query_helpers/test_query_helper.py b/tests/unit/query_helpers/test_query_helper.py
index 35ddb314..19431152 100644
--- a/tests/unit/query_helpers/test_query_helper.py
+++ b/tests/unit/query_helpers/test_query_helper.py
@@ -15,6 +15,7 @@ def test_defaults_used():
     assert qh.model == config.ols_config.default_model
     assert qh.llm_loader is load_llm
     assert qh.generic_llm_params == {}
+    assert qh.streaming is False


 def test_inputs_are_used():
@@ -25,3 +26,12 @@ def test_inputs_are_used():

     assert qh.provider == test_provider
     assert qh.model == test_model
+
+
+def test_streaming_parameter():
+    """Test that the optional streaming parameter is stored."""
+    qh = QueryHelper(streaming=False)
+    assert qh.streaming is False
+
+    qh = QueryHelper(streaming=True)
+    assert qh.streaming is True
diff --git a/tests/unit/query_helpers/test_question_validator.py b/tests/unit/query_helpers/test_question_validator.py
index 4d338814..9939ce58 100644
--- a/tests/unit/query_helpers/test_question_validator.py
+++ b/tests/unit/query_helpers/test_question_validator.py
@@ -6,9 +6,16 @@

 from ols import config
 from ols.constants import GenericLLMParameters
-from ols.src.query_helpers.question_validator import QueryHelper, QuestionValidator
-from tests.mock_classes.mock_llm_chain import mock_llm_chain
-from tests.mock_classes.mock_llm_loader import mock_llm_loader
+
+# needs to be setup there before is_user_authorized is imported
+config.ols_config.authentication_config.module = "k8s"
+
+from ols.src.query_helpers.question_validator import (  # noqa: E402
+    QueryHelper,
+    QuestionValidator,
+)
+from tests.mock_classes.mock_llm_chain import mock_llm_chain  # noqa: E402
+from tests.mock_classes.mock_llm_loader import mock_llm_loader  # noqa: E402


 @pytest.fixture
@@ -67,7 +74,12 @@ def test_validate_question_llm_loader():
     # be performed
     llm_loader = mock_llm_loader(
         None,
-        expected_params=("p1", "m1", {GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: 4}),
+        expected_params=(
+            "p1",
+            "m1",
+            {GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: 4},
+            False,
+        ),
     )

     # check that LLM loader was called with expected parameters