Skip to content

Commit

Permalink
Updated tests accordingly
Browse files Browse the repository at this point in the history
Signed-off-by: Pavel Tisnovsky <[email protected]>
  • Loading branch information
tisnik committed Jan 27, 2025
1 parent ae072eb commit b9bc7eb
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 8 deletions.
2 changes: 1 addition & 1 deletion ols/app/endpoints/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def llm_is_ready() -> bool:
llm_is_ready_persistent_state = False
try:
bare_llm = load_llm(
config.ols_config.default_provider, config.ols_config.default_model
config.ols_config.default_provider, config.ols_config.default_model, False
)
response = bare_llm.invoke(input="Hello there!")
# BAM and Watsonx replies as str and not as `AIMessage`
Expand Down
2 changes: 1 addition & 1 deletion ols/src/query_helpers/query_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
provider: Optional[str] = None,
model: Optional[str] = None,
generic_llm_params: Optional[dict] = None,
llm_loader: Optional[Callable[[str, str, dict], LLM]] = None,
llm_loader: Optional[Callable[[str, str, dict, bool], LLM]] = None,
system_prompt: Optional[str] = None,
streaming: Optional[bool] = None,
) -> None:
Expand Down
4 changes: 3 additions & 1 deletion ols/src/query_helpers/question_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def validate_question(
prompts.QUESTION_VALIDATOR_PROMPT_TEMPLATE
)

bare_llm = self.llm_loader(self.provider, self.model, self.generic_llm_params)
bare_llm = self.llm_loader(
self.provider, self.model, self.generic_llm_params, self.streaming
)

# Tokens-check: We trigger the computation of the token count
# without care about the return value. This is to ensure that
Expand Down
2 changes: 1 addition & 1 deletion tests/mock_classes/mock_llm_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def loader(*args, **kwargs):
# if expected params are provided, check if (mocked) LLM loader
# was called with expected parameters
if expected_params is not None:
assert expected_params == args
assert expected_params == args, expected_params
return MockLLMLoader(llm)

return loader
10 changes: 10 additions & 0 deletions tests/unit/query_helpers/test_query_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_defaults_used():
assert qh.model == config.ols_config.default_model
assert qh.llm_loader is load_llm
assert qh.generic_llm_params == {}
assert qh.streaming is False


def test_inputs_are_used():
Expand All @@ -25,3 +26,12 @@ def test_inputs_are_used():

assert qh.provider == test_provider
assert qh.model == test_model


def test_streaming_parameter():
    """Check that QueryHelper retains the value passed via the `streaming` kwarg."""
    # Exercise both boolean values; identity check guards against truthy coercion.
    for flag in (False, True):
        helper = QueryHelper(streaming=flag)
        assert helper.streaming is flag
20 changes: 16 additions & 4 deletions tests/unit/query_helpers/test_question_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@

from ols import config
from ols.constants import GenericLLMParameters
from ols.src.query_helpers.question_validator import QueryHelper, QuestionValidator
from tests.mock_classes.mock_llm_chain import mock_llm_chain
from tests.mock_classes.mock_llm_loader import mock_llm_loader

# needs to be setup there before is_user_authorized is imported
config.ols_config.authentication_config.module = "k8s"

from ols.src.query_helpers.question_validator import ( # noqa: E402
QueryHelper,
QuestionValidator,
)
from tests.mock_classes.mock_llm_chain import mock_llm_chain # noqa: E402
from tests.mock_classes.mock_llm_loader import mock_llm_loader # noqa: E402


@pytest.fixture
Expand Down Expand Up @@ -67,7 +74,12 @@ def test_validate_question_llm_loader():
# be performed
llm_loader = mock_llm_loader(
None,
expected_params=("p1", "m1", {GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: 4}),
expected_params=(
"p1",
"m1",
{GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: 4},
False,
),
)

# check that LLM loader was called with expected parameters
Expand Down

0 comments on commit b9bc7eb

Please sign in to comment.