Pass streaming parameter to LLM loader (#317)
* Pass streaming parameter to LLM loader

Signed-off-by: Pavel Tisnovsky <[email protected]>

* Updated tests accordingly

Signed-off-by: Pavel Tisnovsky <[email protected]>

* Merge changes into unit test for DocsSummarizer

---------

Signed-off-by: Pavel Tisnovsky <[email protected]>
tisnik authored Jan 27, 2025
1 parent 540a694 commit 7f4ce0d
Showing 10 changed files with 72 additions and 19 deletions.
2 changes: 1 addition & 1 deletion ols/app/endpoints/health.py
@@ -45,7 +45,7 @@ def llm_is_ready() -> bool:
     llm_is_ready_persistent_state = False
     try:
         bare_llm = load_llm(
-            config.ols_config.default_provider, config.ols_config.default_model
+            config.ols_config.default_provider, config.ols_config.default_model, streaming=False
         )
         response = bare_llm.invoke(input="Hello there!")
         # BAM and Watsonx replies as str and not as `AIMessage`
1 change: 1 addition & 0 deletions ols/app/endpoints/ols.py
@@ -376,6 +376,7 @@ def generate_response(
             provider=llm_request.provider,
             model=llm_request.model,
             system_prompt=llm_request.system_prompt,
+            streaming=streaming,
         )
         history = CacheEntry.cache_entries_to_history(previous_input)
         if streaming:
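For orientation, a minimal sketch of what this hunk wires up: generate_response now forwards the request's streaming flag when it constructs the query helper. The receiving class is assumed to be DocsSummarizer, based on the docs_summarizer.py hunk below; the surrounding endpoint body is elided:

    # Hedged sketch, not the verbatim endpoint code.
    summarizer = DocsSummarizer(
        provider=llm_request.provider,
        model=llm_request.model,
        system_prompt=llm_request.system_prompt,
        streaming=streaming,  # new: forwarded down to load_llm
    )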
6 changes: 5 additions & 1 deletion ols/src/llms/llm_loader.py
@@ -53,14 +53,18 @@ def resolve_provider_config(


 def load_llm(
-    provider: str, model: str, generic_llm_params: Optional[dict] = None
+    provider: str,
+    model: str,
+    generic_llm_params: Optional[dict] = None,
+    streaming: Optional[bool] = None,
 ) -> LLM:
     """Load LLM according to input provider and model.
     Args:
         provider: The provider name.
         model: The model name.
         generic_llm_params: The optional parameters that will be converted into LLM-specific ones.
+        streaming: The optional parameter that enables streaming on the LLM side if set to True.
     Raises:
         LLMConfigurationError: If the whole provider configuration is missing.
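A hedged usage sketch of the new load_llm signature; the provider and model names below are placeholders, and only the parameter order and defaults come from this diff:

    from ols.src.llms.llm_loader import load_llm

    # Enable token streaming on the LLM side.
    llm = load_llm("my_provider", "my_model", streaming=True)

    # Omitting the argument leaves streaming as None, i.e. the provider default.
    llm = load_llm("my_provider", "my_model")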
2 changes: 1 addition & 1 deletion ols/src/query_helpers/docs_summarizer.py
@@ -42,7 +42,7 @@ def _prepare_llm(self) -> None:
             GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: self.model_config.parameters.max_tokens_for_response  # noqa: E501
         }
         self.bare_llm = self.llm_loader(
-            self.provider, self.model, self.generic_llm_params
+            self.provider, self.model, self.generic_llm_params, self.streaming
         )

     def _get_system_prompt(self) -> None:
4 changes: 3 additions & 1 deletion ols/src/query_helpers/query_helper.py
@@ -21,8 +21,9 @@ def __init__(
         provider: Optional[str] = None,
         model: Optional[str] = None,
         generic_llm_params: Optional[dict] = None,
-        llm_loader: Optional[Callable[[str, str, dict], LLM]] = None,
+        llm_loader: Optional[Callable[[str, str, dict, bool], LLM]] = None,
         system_prompt: Optional[str] = None,
+        streaming: Optional[bool] = None,
     ) -> None:
         """Initialize query helper."""
         # NOTE: As signature of this method is evaluated before the config,
@@ -32,6 +33,7 @@ def __init__(
         self.model = model or config.ols_config.default_model
         self.generic_llm_params = generic_llm_params or {}
         self.llm_loader = llm_loader or load_llm
+        self.streaming = streaming or False

         self._system_prompt = (
             (config.dev_config.enable_system_prompt_override and system_prompt)
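A short sketch of the defaulting shown above: `streaming or False` normalizes both an omitted argument and an explicit None to False, so self.streaming is always a plain bool (assumes a loaded config, as QueryHelper requires):

    QueryHelper().streaming                  # False: argument omitted
    QueryHelper(streaming=None).streaming    # False: None coerced by `or False`
    QueryHelper(streaming=True).streaming    # True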
4 changes: 3 additions & 1 deletion ols/src/query_helpers/question_validator.py
@@ -57,7 +57,9 @@ def validate_question(
             prompts.QUESTION_VALIDATOR_PROMPT_TEMPLATE
         )

-        bare_llm = self.llm_loader(self.provider, self.model, self.generic_llm_params)
+        bare_llm = self.llm_loader(
+            self.provider, self.model, self.generic_llm_params, self.streaming
+        )

         # Tokens-check: We trigger the computation of the token count
         # without care about the return value. This is to ensure that
2 changes: 1 addition & 1 deletion tests/mock_classes/mock_llm_loader.py
@@ -27,7 +27,7 @@ def loader(*args, **kwargs):
         # if expected params are provided, check if (mocked) LLM loader
         # was called with expected parameters
         if expected_params is not None:
-            assert expected_params == args
+            assert expected_params == args, expected_params
         return MockLLMLoader(llm)

     return loader
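A hedged note on the assertion message: passing expected_params as the second operand of assert makes a mismatch report the expected tuple directly in the test output. Roughly, with hypothetical values:

    # loader("p1", "m1", {}) against expected_params=("p1", "m1", {}, False)
    # now fails with: AssertionError: ('p1', 'm1', {}, False)
    # which shows at a glance that the fourth (streaming) element was missing.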
40 changes: 31 additions & 9 deletions tests/unit/query_helpers/test_docs_summarizer.py
@@ -7,15 +7,25 @@
 from langchain_core.messages import AIMessage, HumanMessage

 from ols import config
-from ols.app.models.config import LoggingConfig
-from ols.src.query_helpers.docs_summarizer import DocsSummarizer, QueryHelper
-from ols.utils import suid
-from ols.utils.logging_configurator import configure_logging
-from tests import constants
-from tests.mock_classes.mock_langchain_interface import mock_langchain_interface
-from tests.mock_classes.mock_llama_index import MockLlamaIndex
-from tests.mock_classes.mock_llm_chain import mock_llm_chain
-from tests.mock_classes.mock_llm_loader import mock_llm_loader
+
+# needs to be set up here before is_user_authorized is imported
+config.ols_config.authentication_config.module = "k8s"
+
+
+from ols.app.models.config import LoggingConfig  # noqa: E402
+from ols.src.query_helpers.docs_summarizer import (  # noqa: E402
+    DocsSummarizer,
+    QueryHelper,
+)
+from ols.utils import suid  # noqa: E402
+from ols.utils.logging_configurator import configure_logging  # noqa: E402
+from tests import constants  # noqa: E402
+from tests.mock_classes.mock_langchain_interface import (  # noqa: E402
+    mock_langchain_interface,
+)
+from tests.mock_classes.mock_llama_index import MockLlamaIndex  # noqa: E402
+from tests.mock_classes.mock_llm_chain import mock_llm_chain  # noqa: E402
+from tests.mock_classes.mock_llm_loader import mock_llm_loader  # noqa: E402

 conversation_id = suid.get_suid()

@@ -51,6 +61,18 @@ def test_if_system_prompt_was_updated():
     assert summarizer.system_prompt == expected_prompt


+def test_docs_summarizer_streaming_parameter():
+    """Test if the optional streaming parameter is stored."""
+    summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None))
+    assert summarizer.streaming is False
+
+    summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None), streaming=False)
+    assert summarizer.streaming is False
+
+    summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None), streaming=True)
+    assert summarizer.streaming is True
+
+
 @patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4)
 @patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 1)
 @patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None))
10 changes: 10 additions & 0 deletions tests/unit/query_helpers/test_query_helper.py
@@ -15,6 +15,7 @@ def test_defaults_used():
     assert qh.model == config.ols_config.default_model
     assert qh.llm_loader is load_llm
     assert qh.generic_llm_params == {}
+    assert qh.streaming is False


 def test_inputs_are_used():
@@ -25,3 +26,12 @@
     assert qh.provider == test_provider
     assert qh.model == test_model
+
+
+def test_streaming_parameter():
+    """Test that the optional streaming parameter is stored."""
+    qh = QueryHelper(streaming=False)
+    assert qh.streaming is False
+
+    qh = QueryHelper(streaming=True)
+    assert qh.streaming is True
20 changes: 16 additions & 4 deletions tests/unit/query_helpers/test_question_validator.py
@@ -6,9 +6,16 @@

 from ols import config
 from ols.constants import GenericLLMParameters
-from ols.src.query_helpers.question_validator import QueryHelper, QuestionValidator
-from tests.mock_classes.mock_llm_chain import mock_llm_chain
-from tests.mock_classes.mock_llm_loader import mock_llm_loader
+
+# needs to be set up here before is_user_authorized is imported
+config.ols_config.authentication_config.module = "k8s"
+
+from ols.src.query_helpers.question_validator import (  # noqa: E402
+    QueryHelper,
+    QuestionValidator,
+)
+from tests.mock_classes.mock_llm_chain import mock_llm_chain  # noqa: E402
+from tests.mock_classes.mock_llm_loader import mock_llm_loader  # noqa: E402


 @pytest.fixture
@@ -67,7 +74,12 @@ def test_validate_question_llm_loader():
     # be performed
     llm_loader = mock_llm_loader(
         None,
-        expected_params=("p1", "m1", {GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: 4}),
+        expected_params=(
+            "p1",
+            "m1",
+            {GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: 4},
+            False,
+        ),
     )

     # check that LLM loader was called with expected parameters
