Replacing chain with simple LLM
Signed-off-by: Jiri Podivin <[email protected]>
jpodivin committed Feb 10, 2025
1 parent 734cc91 commit 085294e
Showing 3 changed files with 9 additions and 18 deletions.
ols/src/query_helpers/question_validator.py (5 additions, 15 deletions)
@@ -3,9 +3,6 @@
 import logging
 from typing import Any
 
-from langchain.chains import LLMChain
-from langchain.prompts import PromptTemplate
-
 from ols import config
 from ols.app.metrics import TokenMetricUpdater
 from ols.constants import SUBJECT_REJECTED, GenericLLMParameters
@@ -53,10 +50,9 @@ def validate_question(
         )
         logger.info("%s call settings: %s", conversation_id, settings_string)
 
-        prompt_instructions = PromptTemplate.from_template(
-            prompts.QUESTION_VALIDATOR_PROMPT_TEMPLATE
+        prompt_instructions = prompts.QUESTION_VALIDATOR_PROMPT_TEMPLATE.replace(
+            "{query}", query
         )
-
         bare_llm = self.llm_loader(
             self.provider, self.model, self.generic_llm_params, self.streaming
         )
@@ -70,23 +66,17 @@ def validate_question(
             query, model_config.context_window_size, self.max_tokens_for_response
         )
 
-        llm_chain = LLMChain(
-            llm=bare_llm,
-            prompt=prompt_instructions,
-            verbose=verbose,
-        )
-
        logger.debug("%s validating user query: %s", conversation_id, query)
 
         with TokenMetricUpdater(
             llm=bare_llm,
             provider=provider_config.type,
             model=self.model,
         ) as generic_token_counter:
-            response = llm_chain.invoke(
-                input={"query": query}, config={"callbacks": [generic_token_counter]}
+            response = bare_llm.invoke(
+                input=prompt_instructions, config={"callbacks": [generic_token_counter]}
             )
-        clean_response = str(response["text"]).strip()
+        clean_response = response.strip()
 
         logger.debug(
             "%s query validation response: %s", conversation_id, clean_response
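
The substance of this file's change: LLMChain, deprecated in newer LangChain releases in favor of invoking models and runnables directly, is dropped; the {query} placeholder is filled with plain str.replace() instead of a PromptTemplate, and the rendered string goes straight to bare_llm.invoke(). Below is a minimal standalone sketch of the resulting pattern, using LangChain's FakeListLLM in place of the model returned by the repository's loader (the template text is illustrative, not the real prompt):

# Sketch only: FakeListLLM stands in for the model from llm_loader.
from langchain_community.llms import FakeListLLM

# Illustrative template; the real one lives in ols prompts.
QUESTION_VALIDATOR_PROMPT_TEMPLATE = "Is this query on topic? {query}"

llm = FakeListLLM(responses=["ACCEPTED"])

# Plain string substitution replaces PromptTemplate.from_template(...).
prompt = QUESTION_VALIDATOR_PROMPT_TEMPLATE.replace("{query}", "How do I restart a pod?")

# Direct invocation replaces LLMChain; a (non-chat) LLM returns a str,
# so .strip() applies to the response as-is.
response = llm.invoke(prompt, config={"callbacks": []})
clean_response = response.strip()
assert clean_response == "ACCEPTED"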
tests/mock_classes/mock_llm_loader.py (4 additions, 0 deletions)
@@ -19,6 +19,10 @@ async def astream(self, llm_input, **kwargs):
         # yield input prompt/user query
         yield llm_input[1].content
 
+    def invoke(self, input, config=None, **kwargs):
+        """Transform a single input into an output."""
+        return input
+
 
 def mock_llm_loader(llm=None, expected_params=None):
     """Construct mock for load_llm."""
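
The added invoke() mirrors the Runnable method the validator now calls on the loaded model: the mock simply echoes its input, so in a test the "response" is the rendered prompt itself, and the validator's response.strip() works because that echo is a plain str. A tiny standalone illustration (not code from the repository):

class MockLLM:
    """Stripped-down stand-in mirroring the diff above."""

    def invoke(self, input, config=None, **kwargs):
        """Transform a single input into an output."""
        return input

prompt = "Is this query on topic? How do I restart a pod?"
# The echoed prompt is a str, so downstream .strip() is safe.
assert MockLLM().invoke(prompt, config={"callbacks": []}).strip() == prompt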
tests/unit/query_helpers/test_question_validator.py (0 additions, 3 deletions)
@@ -1,6 +1,5 @@
 """Unit tests for QuestionValidator class."""
 
-from unittest.mock import patch
 
 import pytest
 
@@ -14,7 +13,6 @@
     QueryHelper,
     QuestionValidator,
 )
-from tests.mock_classes.mock_llm_chain import mock_llm_chain  # noqa: E402
 from tests.mock_classes.mock_llm_loader import mock_llm_loader  # noqa: E402


@@ -63,7 +61,6 @@ def test_passing_parameters():
     )
 
 
-@patch("ols.src.query_helpers.question_validator.LLMChain", new=mock_llm_chain(None))
 def test_validate_question_llm_loader():
     """Test that LLM is loaded within validate_question method with proper parameters."""
     # it is needed to initialize configuration in order to be able
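
With the mock model answering invoke() itself, there is no chain class left to substitute, so the @patch of LLMChain and both related imports can go. A self-contained sketch of the underlying testing principle (all names here are hypothetical, not the repository's):

class EchoLLM:
    """Fake model that echoes its prompt, like the mock above."""

    def invoke(self, input, config=None, **kwargs):
        return input

def validate(query, llm_loader):
    """Toy validator: load a model via the injected loader and invoke it."""
    llm = llm_loader()
    return llm.invoke("Is this query on topic? " + query).strip()

def test_validate_uses_loader():
    # Mocking the loader is enough; nothing else needs patching.
    assert validate("How do I restart a pod?", EchoLLM) == (
        "Is this query on topic? How do I restart a pod?"
    )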
