From 542543d62f82011e7226e253e750719d28bdcd0c Mon Sep 17 00:00:00 2001
From: Ondrej Metelka
Date: Wed, 8 Jan 2025 10:22:08 +0100
Subject: [PATCH 1/4] Implement streaming response endpoint

---
 README.md | 2 +
 docs/openapi.json | 98 ++++-
 ols/app/endpoints/ols.py | 154 ++++---
 ols/app/endpoints/streaming_ols.py | 358 ++++++++++++++++
 ols/app/models/models.py | 9 +
 ols/app/routers.py | 3 +-
 ols/constants.py | 8 +-
 ols/src/query_helpers/docs_summarizer.py | 144 ++++---
 tests/integration/test_ols.py | 393 ++++++++++--------
 tests/mock_classes/mock_llm_loader.py | 5 +
 tests/unit/app/endpoints/test_ols.py | 40 +-
 .../unit/app/endpoints/test_streaming_ols.py | 131 ++++++
 tests/unit/app/models/test_models.py | 19 +
 .../query_helpers/test_docs_summarizer.py | 30 +-
 14 files changed, 1111 insertions(+), 283 deletions(-)
 create mode 100644 ols/app/endpoints/streaming_ols.py
 create mode 100644 tests/unit/app/endpoints/test_streaming_ols.py

diff --git a/README.md b/README.md
index b3d75d37..420128cf 100644
--- a/README.md
+++ b/README.md
@@ -609,6 +609,8 @@ To send a request to the server you can use the following curl command:
 ```
 curl -X 'POST' 'http://127.0.0.1:8080/v1/query' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"query": "write a deployment yaml for the mongodb image"}'
 ```
+> You can use the `/v1/streaming_query` endpoint (with the same parameters) to get a streaming response (SSE/HTTP chunking). By default, it streams text, but you can also receive the events as JSON via an additional `"media_type": "application/json"` parameter in the payload data.
+
 ### Swagger UI
 
 Web page with Swagger UI has the standard `/docs` endpoint. If the service is running on localhost on port 8080, Swagger UI can be accessed on the address `http://localhost:8080/docs`.
diff --git a/docs/openapi.json b/docs/openapi.json
index 65d1c44b..72d402a8 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -88,6 +88,89 @@
                 }
             }
         },
+        "/v1/streaming_query": {
+            "post": {
+                "tags": [
+                    "streaming_query"
+                ],
+                "summary": "Conversation Request",
+                "description": "Handle conversation requests for the OLS endpoint.\n\nArgs:\n    llm_request: The incoming request containing query details.\n    auth: The authentication context, provided by dependency injection.\n\nReturns:\n    StreamingResponse: The streaming response generated for the query.",
+                "operationId": "conversation_request_v1_streaming_query_post",
+                "requestBody": {
+                    "content": {
+                        "application/json": {
+                            "schema": {
+                                "$ref": "#/components/schemas/LLMRequest"
+                            }
+                        }
+                    },
+                    "required": true
+                },
+                "responses": {
+                    "200": {
+                        "description": "Query is valid and stream/events from endpoint is returned",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "type": "string",
+                                    "title": "Response 200 Conversation Request V1 Streaming Query Post"
+                                }
+                            }
+                        }
+                    },
+                    "401": {
+                        "description": "Missing or invalid credentials provided by client",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/UnauthorizedResponse"
+                                }
+                            }
+                        }
+                    },
+                    "403": {
+                        "description": "Client does not have permission to access resource",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/ForbiddenResponse"
+                                }
+                            }
+                        }
+                    },
+                    "413": {
+                        "description": "Prompt is too long",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/PromptTooLongResponse"
+                                }
+                            }
+                        }
+                    },
+                    "500": {
+                        "description": "Query can not be validated, LLM is not accessible or other internal error",
+                        "content": {
+                            "application/json": {
+ "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/feedback/status": { "get": { "tags": [ @@ -579,6 +662,18 @@ } ], "title": "Attachments" + }, + "media_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Media Type", + "default": "text/plain" } }, "additionalProperties": false, @@ -587,7 +682,7 @@ "query" ], "title": "LLMRequest", - "description": "Model representing a request for the LLM (Language Model) send into OLS service.\n\nAttributes:\n query: The query string.\n conversation_id: The optional conversation ID (UUID).\n provider: The optional provider.\n model: The optional model.\n attachments: The optional attachments.\n\nExample:\n ```python\n llm_request = LLMRequest(query=\"Tell me about Kubernetes\")\n ```", + "description": "Model representing a request for the LLM (Language Model) send into OLS service.\n\nAttributes:\n query: The query string.\n conversation_id: The optional conversation ID (UUID).\n provider: The optional provider.\n model: The optional model.\n attachments: The optional attachments.\n media_type: The optional parameter for streaming response.\n\nExample:\n ```python\n llm_request = LLMRequest(query=\"Tell me about Kubernetes\")\n ```", "examples": [ { "attachments": [ @@ -608,6 +703,7 @@ } ], "conversation_id": "123e4567-e89b-12d3-a456-426614174000", + "media_type": "text/plain", "model": "gpt-4o-mini", "provider": "openai", "query": "write a deployment yaml for the mongodb image", diff --git a/ols/app/endpoints/ols.py b/ols/app/endpoints/ols.py index a27af640..98c62a4b 100644 --- a/ols/app/endpoints/ols.py +++ b/ols/app/endpoints/ols.py @@ -7,7 +7,7 @@ import time from datetime import datetime from pathlib import Path -from typing import Any, Optional +from typing import Any, Generator, Optional, Union import pytz from fastapi import APIRouter, Depends, HTTPException, status @@ -79,49 +79,15 @@ def conversation_request( Returns: Response containing the processed information. """ - timestamps: dict[str, float] = {} - timestamps["start"] = time.time() - - # Initialize variables - previous_input = [] - - user_id = retrieve_user_id(auth) - logger.info("User ID %s", user_id) - timestamps["retrieve user"] = time.time() - - conversation_id = retrieve_conversation_id(llm_request) - timestamps["retrieve conversation"] = time.time() - - # Important note: Redact the query before attempting to do any - # logging of the query to avoid leaking PII into logs. 
- - # Redact the query - llm_request = redact_query(conversation_id, llm_request) - timestamps["redact query"] = time.time() - - # Log incoming request (after redaction) - logger.info("%s Incoming request: %s", conversation_id, llm_request.query) - - previous_input = retrieve_previous_input(user_id, llm_request) - timestamps["retrieve previous input"] = time.time() - - # Retrieve attachments from the request - attachments = retrieve_attachments(llm_request) - - # Redact all attachments - attachments = redact_attachments(conversation_id, attachments) - - # All attachments should be appended to query - but store original - # query for later use in transcript storage - query_without_attachments = llm_request.query - llm_request.query = append_attachments_to_query(llm_request.query, attachments) - timestamps["append attachments"] = time.time() - - validate_requested_provider_model(llm_request) - - # Validate the query - valid = validate_question(conversation_id, llm_request) - timestamps["validate question"] = time.time() + ( + user_id, + conversation_id, + query_without_attachments, + previous_input, + attachments, + valid, + timestamps, + ) = process_request(auth, llm_request) if not valid: summarizer_response = SummarizerResponse( @@ -132,7 +98,7 @@ def conversation_request( else: summarizer_response = generate_response( conversation_id, llm_request, previous_input - ) + ) # type: ignore[assignment] timestamps["generate response"] = time.time() @@ -172,6 +138,76 @@ def conversation_request( ) +def process_request( + auth: Any, llm_request: LLMRequest +) -> tuple[str, str, str, list[CacheEntry], list[Attachment], bool, dict[str, float]]: + """Process incoming request. + + Args: + auth: The Authentication handler (FastAPI Depends) that will handle authentication Logic. + llm_request: The request containing a query, conversation ID, and optional attachments. + + Returns: + Tuple containing the processed information. + User ID, conversation ID, query without attachments, previous input, + attachments, validation result and timestamps. + """ + timestamps = {"start": time.time()} + + user_id = retrieve_user_id(auth) + logger.info("User ID %s", user_id) + timestamps["retrieve user"] = time.time() + + conversation_id = retrieve_conversation_id(llm_request) + timestamps["retrieve conversation"] = time.time() + + # Important note: Redact the query before attempting to do any + # logging of the query to avoid leaking PII into logs. 
+ + # Redact the query + llm_request = redact_query(conversation_id, llm_request) + timestamps["redact query"] = time.time() + + # Log incoming request (after redaction) + logger.info("%s Incoming request: %s", conversation_id, llm_request.query) + + previous_input = retrieve_previous_input(user_id, llm_request) + timestamps["retrieve previous input"] = time.time() + + # Retrieve attachments from the request + attachments = retrieve_attachments(llm_request) + + # Redact all attachments + attachments = redact_attachments(conversation_id, attachments) + + # All attachments should be appended to query - but store original + # query for later use in transcript storage + query_without_attachments = llm_request.query + llm_request.query = append_attachments_to_query(llm_request.query, attachments) + timestamps["append attachments"] = time.time() + + validate_requested_provider_model(llm_request) + + # Validate the query + if not previous_input: + valid = validate_question(conversation_id, llm_request) + else: + logger.debug("follow-up conversation - skipping question validation") + valid = True + + timestamps["validate question"] = time.time() + + return ( + user_id, + conversation_id, + query_without_attachments, + previous_input, + attachments, + valid, + timestamps, + ) + + def log_processing_durations(timestamps: dict[str, float]) -> None: """Log processing durations.""" @@ -285,9 +321,19 @@ def generate_response( conversation_id: str, llm_request: LLMRequest, previous_input: list[CacheEntry], -) -> SummarizerResponse: - """Generate response based on validation result, previous input, and model output.""" - # Summarize documentation + streaming: bool = False, +) -> Union[SummarizerResponse, Generator]: + """Generate response based on validation result, previous input, and model output. + + Args: + conversation_id: The unique identifier for the conversation. + llm_request: The request containing a query. + previous_input: The history of the conversation (if available). + streaming: The flag indicating if the response should be streamed. + + Returns: + SummarizerResponse or Generator, depending on the streaming flag. + """ try: docs_summarizer = DocsSummarizer( provider=llm_request.provider, @@ -295,9 +341,15 @@ def generate_response( system_prompt=llm_request.system_prompt, ) history = CacheEntry.cache_entries_to_history(previous_input) - return docs_summarizer.summarize( - conversation_id, llm_request.query, config.rag_index, history + if streaming: + return docs_summarizer.generate_response( + llm_request.query, config.rag_index, history + ) + response = docs_summarizer.create_response( + llm_request.query, config.rag_index, history ) + logger.debug("%s Generated response: %s", conversation_id, response) + return response except PromptTooLongError as summarizer_error: logger.error("Prompt is too long: %s", summarizer_error) raise HTTPException( @@ -310,7 +362,7 @@ def generate_response( except Exception as summarizer_error: logger.error("Error while obtaining answer for user question") logger.exception(summarizer_error) - status_code, response, cause = errors_parsing.parse_generic_llm_error( + status_code, response, cause = errors_parsing.parse_generic_llm_error( # type: ignore[assignment] summarizer_error ) raise HTTPException( diff --git a/ols/app/endpoints/streaming_ols.py b/ols/app/endpoints/streaming_ols.py new file mode 100644 index 00000000..58779850 --- /dev/null +++ b/ols/app/endpoints/streaming_ols.py @@ -0,0 +1,358 @@ +"""FastAPI endpoint for the OLS streaming query. 
+ +This module defines the endpoint and supporting functions for handling +streaming queries. +""" + +import json +import logging +import time +from typing import Any, AsyncGenerator + +from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse + +from ols import config, constants +from ols.app.endpoints.ols import ( + generate_response, + log_processing_durations, + process_request, + store_conversation_history, + store_transcript, +) +from ols.app.models.models import ( + Attachment, + ErrorResponse, + ForbiddenResponse, + LLMRequest, + PromptTooLongResponse, + RagChunk, + ReferencedDocument, + SummarizerResponse, + UnauthorizedResponse, +) +from ols.constants import MEDIA_TYPE_TEXT +from ols.src.auth.auth import get_auth_dependency +from ols.utils import errors_parsing +from ols.utils.token_handler import PromptTooLongError + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["streaming_query"]) +auth_dependency = get_auth_dependency(config.ols_config, virtual_path="/ols-access") + + +query_responses: dict[int | str, dict[str, Any]] = { + 200: { + "description": "Query is valid and stream/events from endpoint is returned", + "model": str, + }, + 401: { + "description": "Missing or invalid credentials provided by client", + "model": UnauthorizedResponse, + }, + 403: { + "description": "Client does not have permission to access resource", + "model": ForbiddenResponse, + }, + 413: { + "description": "Prompt is too long", + "model": PromptTooLongResponse, + }, + 500: { + "description": "Query can not be validated, LLM is not accessible or other internal error", + "model": ErrorResponse, + }, +} + + +@router.post("/streaming_query", responses=query_responses) +def conversation_request( + llm_request: LLMRequest, auth: Any = Depends(auth_dependency) +) -> StreamingResponse: + """Handle conversation requests for the OLS endpoint. + + Args: + llm_request: The incoming request containing query details. + auth: The authentication context, provided by dependency injection. + + Returns: + StreamingResponse: The streaming response generated for the query. + """ + ( + user_id, + conversation_id, + query_without_attachments, + previous_input, + attachments, + valid, + timestamps, + ) = process_request(auth, llm_request) + + summarizer_response = ( + invalid_response_generator() + if not valid + else generate_response( + conversation_id, llm_request, previous_input, streaming=True + ) + ) + + return StreamingResponse( + response_processing_wrapper( + summarizer_response, + user_id, + conversation_id, + llm_request, + attachments, + valid, + query_without_attachments, + llm_request.media_type, + timestamps, + ), + media_type=llm_request.media_type, + ) + + +async def invalid_response_generator() -> AsyncGenerator[str, None]: + """Yield an invalid query response. + + Yields: + str: The response indicating invalid query. + """ + yield constants.INVALID_QUERY_RESP + + +def stream_start_event(conversation_id: str) -> str: + """Yield the start of the data stream. + + Args: + conversation_id: The conversation ID (UUID). + """ + return json.dumps( + { + "event": "start", + "data": { + "conversation_id": conversation_id, + }, + } + ) + + +def stream_end_event(ref_docs: list[dict], truncated: bool, media_type: str) -> str: + """Yield the end of the data stream. + + Args: + ref_docs: Referenced documents. + truncated: Indicates if the history was truncated. + media_type: Media type of the response (e.g. text or JSON). 
+ """ + if media_type == constants.MEDIA_TYPE_JSON: + return json.dumps( + { + "event": "end", + "data": { + "referenced_documents": ref_docs, + "truncated": truncated, + }, + } + ) + ref_docs_string = "\n".join( + f'{item["doc_title"]}: {item["doc_url"]}' for item in ref_docs + ) + return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" + + +def build_referenced_docs(rag_chunks: list[RagChunk]) -> list[dict]: + """Build a list of unique referenced documents.""" + referenced_documents = ReferencedDocument.from_rag_chunks(rag_chunks) + return [ + { + "doc_title": doc.title, + "doc_url": doc.docs_url, + } + for doc in referenced_documents + ] + + +def prompt_too_long_error(error: PromptTooLongError, media_type: str) -> str: + """Return error representation for long prompts. + + Args: + error: The exception raised for long prompts. + media_type: Media type of the response (e.g. text or JSON). + + Returns: + str: The error message formatted for the media type. + """ + logger.error("Prompt is too long: %s", error) + if media_type == MEDIA_TYPE_TEXT: + return f"Prompt is too long: {error}" + return json.dumps( + { + "event": "error", + "data": { + "response": "Prompt is too long", + "cause": str(error), + }, + } + ) + + +def generic_llm_error(error: Exception, media_type: str) -> str: + """Return error representation for generic LLM errors. + + Args: + error: The exception raised during processing. + media_type: Media type of the response (e.g. text or JSON). + + Returns: + str: The error message formatted for the media type. + """ + logger.error("Error while obtaining answer for user question") + logger.exception(error) + _, response, cause = errors_parsing.parse_generic_llm_error(error) + + if media_type == MEDIA_TYPE_TEXT: + return f"{response}: {cause}" + return json.dumps( + { + "event": "error", + "data": { + "response": response, + "cause": cause, + }, + } + ) + + +def build_yield_item(item: str, idx: int, media_type: str) -> str: + """Build an item to yield based on media type. + + Args: + item: The token or string fragment to yield. + idx: Index of the current item in the stream. + media_type: Media type of the response (e.g. text or JSON). + + Returns: + str: The formatted string or JSON to yield. + """ + if media_type == MEDIA_TYPE_TEXT: + return item + return json.dumps({"event": "token", "data": {"id": idx, "token": item}}) + + +def store_data( + user_id: str, + conversation_id: str, + llm_request: LLMRequest, + response: str, + attachments: list[Attachment], + valid: bool, + query_without_attachments: str, + rag_chunks: list[RagChunk], + history_truncated: bool, + timestamps: dict[str, float], +) -> None: + """Store conversation history and transcript if enabled. + + Args: + user_id: The user ID (UUID). + conversation_id: The conversation ID (UUID). + llm_request: The original request. + response: The generated response. + attachments: list of attachments included in the query. + valid: Indicates if the query was valid. + query_without_attachments: Query content excluding attachments. + rag_chunks: list of RAG (Retrieve-And-Generate) chunks used in the response. + history_truncated: Indicates if the conversation history was truncated. + timestamps: Dictionary tracking timestamps for various stages. 
+ """ + store_conversation_history( + user_id, conversation_id, llm_request, response, attachments + ) + + if not config.ols_config.user_data_collection.transcripts_disabled: + store_transcript( + user_id, + conversation_id, + valid, + query_without_attachments, + llm_request, + response, + rag_chunks, + history_truncated, + attachments, + ) + timestamps["store transcripts"] = time.time() + + +async def response_processing_wrapper( + generator: AsyncGenerator[Any, None], + user_id: str, + conversation_id: str, + llm_request: LLMRequest, + attachments: list[Attachment], + valid: bool, + query_without_attachments: str, + media_type: str, + timestamps: dict[str, float], +) -> AsyncGenerator[str, None]: + """Process the response from the generator and handle metadata and errors. + + Args: + generator: The async generator providing summarizer responses. + user_id: The user ID (UUID). + conversation_id: The conversation ID (UUID). + llm_request: The original request. + attachments: list of attachments included in the query. + valid: Indicates if the query was valid. + query_without_attachments: Query content excluding attachments. + media_type: Media type of the response (e.g. text or JSON). + timestamps: Dictionary tracking timestamps for various stages. + + Yields: + str: The response items or error messages. + """ + if media_type == constants.MEDIA_TYPE_JSON: + yield stream_start_event(conversation_id) + + response = "" + rag_chunks = [] + history_truncated = False + idx = 0 + try: + async for item in generator: + if isinstance(item, SummarizerResponse): + rag_chunks = item.rag_chunks + history_truncated = item.history_truncated + break + + response += item + yield build_yield_item(item, idx, media_type) + idx += 1 + except PromptTooLongError as summarizer_error: + yield prompt_too_long_error(summarizer_error, media_type) + except Exception as summarizer_error: + yield generic_llm_error(summarizer_error, media_type) + timestamps["generate response"] = time.time() + + store_data( + user_id, + conversation_id, + llm_request, + response, + attachments, + valid, + query_without_attachments, + rag_chunks, + history_truncated, + timestamps, + ) + + yield stream_end_event( + build_referenced_docs(rag_chunks), history_truncated, media_type + ) + + timestamps["add references"] = time.time() + + log_processing_durations(timestamps) diff --git a/ols/app/models/models.py b/ols/app/models/models.py index 6dad8094..818c5cdf 100644 --- a/ols/app/models/models.py +++ b/ols/app/models/models.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, field_validator, model_validator from pydantic.dataclasses import dataclass +from ols.constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT from ols.customize import prompts from ols.utils import suid @@ -66,6 +67,7 @@ class LLMRequest(BaseModel): provider: The optional provider. model: The optional model. attachments: The optional attachments. + media_type: The optional parameter for streaming response. Example: ```python @@ -79,6 +81,7 @@ class LLMRequest(BaseModel): model: Optional[str] = None system_prompt: Optional[str] = None attachments: Optional[list[Attachment]] = None + media_type: Optional[str] = MEDIA_TYPE_TEXT # provides examples for /docs endpoint model_config = { @@ -108,6 +111,7 @@ class LLMRequest(BaseModel): "content": "foo: bar", }, ], + "media_type": "text/plain", } ] }, @@ -124,6 +128,11 @@ def validate_provider_and_model(self) -> Self: raise ValueError( "LLM model must be specified when the provider is specified." 
) + if self.media_type not in (MEDIA_TYPE_TEXT, MEDIA_TYPE_JSON): + raise ValueError( + f"Invalid media type: '{self.media_type}', must be " + f"{MEDIA_TYPE_TEXT} or {MEDIA_TYPE_JSON}" + ) return self diff --git a/ols/app/routers.py b/ols/app/routers.py index 5c7d9f49..0096087e 100644 --- a/ols/app/routers.py +++ b/ols/app/routers.py @@ -2,7 +2,7 @@ from fastapi import FastAPI -from ols.app.endpoints import authorized, feedback, health, ols +from ols.app.endpoints import authorized, feedback, health, ols, streaming_ols from ols.app.metrics import metrics @@ -13,6 +13,7 @@ def include_routers(app: FastAPI) -> None: app: The `FastAPI` app instance. """ app.include_router(ols.router, prefix="/v1") + app.include_router(streaming_ols.router, prefix="/v1") app.include_router(feedback.router, prefix="/v1") app.include_router(health.router) app.include_router(metrics.router) diff --git a/ols/constants.py b/ols/constants.py index 6d8dd5a8..8bd6cd02 100644 --- a/ols/constants.py +++ b/ols/constants.py @@ -38,12 +38,10 @@ class QueryValidationMethod(StrEnum): PROVIDER_FAKE, } ) - DEFAULT_AZURE_API_VERSION = "2024-02-15-preview" -# models - +# models class ModelFamily(StrEnum): """Different LLM models family/group.""" @@ -222,3 +220,7 @@ class GenericLLMParameters: # Environment variable containing configuration file name to override default # configuration file CONFIGURATION_FILE_NAME_ENV_VARIABLE = "RCS_CONFIG_FILE" + +# Response streaming media types +MEDIA_TYPE_TEXT = "text/plain" +MEDIA_TYPE_JSON = "application/json" diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py index 42069023..044a6dd5 100644 --- a/ols/src/query_helpers/docs_summarizer.py +++ b/ols/src/query_helpers/docs_summarizer.py @@ -1,17 +1,19 @@ """A class for summarizing documentation context.""" import logging -from typing import Any, Optional +from typing import Any, AsyncGenerator, Optional from langchain.chains import LLMChain +from langchain_core.prompts import ChatPromptTemplate from llama_index.core import VectorStoreIndex from ols import config from ols.app.metrics import TokenMetricUpdater -from ols.app.models.models import SummarizerResponse +from ols.app.models.models import RagChunk, SummarizerResponse from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters from ols.customize import reranker from ols.src.prompts.prompt_generator import GeneratePrompt +from ols.src.prompts.prompts import QUERY_SYSTEM_INSTRUCTION from ols.src.query_helpers.query_helper import QueryHelper from ols.utils.token_handler import TokenHandler @@ -24,62 +26,72 @@ class DocsSummarizer(QueryHelper): def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize the QuestionValidator.""" super().__init__(*args, **kwargs) - provider_config = config.llm_config.providers.get(self.provider) - model_config = provider_config.models.get(self.model) + self._prepare_llm() + self._get_system_prompt() + self.verbose = config.ols_config.logging_config.app_log_level == logging.DEBUG + + def _prepare_llm(self) -> None: + """Prepare the LLM configuration.""" + self.provider_config = config.llm_config.providers.get(self.provider) + self.model_config = self.provider_config.models.get(self.model) self.generic_llm_params = { - GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: model_config.parameters.max_tokens_for_response # noqa: E501 + GenericLLMParameters.MAX_TOKENS_FOR_RESPONSE: self.model_config.parameters.max_tokens_for_response # noqa: E501 } + self.bare_llm = self.llm_loader( + self.provider, self.model, 
self.generic_llm_params + ) + + def _get_system_prompt(self) -> None: + """Retrieve the system prompt.""" + # use system prompt from config if available otherwise use + # default system prompt fine-tuned for the service + if config.ols_config.system_prompt is not None: + self.system_prompt = config.ols_config.system_prompt + else: + self.system_prompt = QUERY_SYSTEM_INSTRUCTION + logger.debug("System prompt: %s", self.system_prompt) - def summarize( + def _prepare_prompt( self, - conversation_id: str, query: str, vector_index: Optional[VectorStoreIndex] = None, history: Optional[list[str]] = None, - ) -> SummarizerResponse: + ) -> tuple[ChatPromptTemplate, dict[str, str], list[RagChunk], bool]: """Summarize the given query based on the provided conversation context. Args: - conversation_id: The unique identifier for the conversation. query: The query to be summarized. - vector_index: Vector index to get rag data/context. + vector_index: Vector index to get RAG data/context. history: The history of the conversation (if available). Returns: - A `SummarizerResponse` object. + A tuple containing the final prompt, input values, RAG chunks, + and a flag for truncated history. """ - # if history is not provided, initialize to empty history - if history is None: - history = [] - - verbose = config.ols_config.logging_config.app_log_level == logging.DEBUG - settings_string = ( - f"conversation_id: {conversation_id}, " f"query: {query}, " f"provider: {self.provider}, " f"model: {self.model}, " - f"verbose: {verbose}" + f"verbose: {self.verbose}" ) - logger.debug("%s call settings: %s", conversation_id, settings_string) + logger.debug("call settings: %s", settings_string) token_handler = TokenHandler() - bare_llm = self.llm_loader(self.provider, self.model, self.generic_llm_params) - provider_config = config.llm_config.providers.get(self.provider) - model_config = provider_config.models.get(self.model) - # Use sample text for context/history to get complete prompt instruction. - # This is used to calculate available tokens. + # Use sample text for context/history to get complete prompt + # instruction. This is used to calculate available tokens. temp_prompt, temp_prompt_input = GeneratePrompt( - query, ["sample"], ["ai: sample"], self._system_prompt + query, ["sample"], ["ai: sample"], self.system_prompt ).generate_prompt(self.model) + available_tokens = token_handler.calculate_and_check_available_tokens( temp_prompt.format(**temp_prompt_input), - model_config.context_window_size, - model_config.parameters.max_tokens_for_response, + self.model_config.context_window_size, + self.model_config.parameters.max_tokens_for_response, ) - if vector_index is not None: + # Retrieve RAG content + if vector_index: retriever = vector_index.as_retriever(similarity_top_k=RAG_CONTENT_LIMIT) retrieved_nodes = retriever.retrieve(query) retrieved_nodes = reranker.rerank(retrieved_nodes) @@ -89,16 +101,17 @@ def summarize( else: logger.warning("Proceeding without RAG content. 
Check start up messages.") rag_chunks = [] - rag_context = [rag_chunk.text for rag_chunk in rag_chunks] + if len(rag_context) == 0: + logger.debug("Using llm to answer the query without reference content") - # Truncate history, if applicable + # Truncate history history, truncated = token_handler.limit_conversation_history( - history, self.model, available_tokens + history or [], self.model, available_tokens ) final_prompt, llm_input_values = GeneratePrompt( - query, rag_context, history, self._system_prompt + query, rag_context, history, self.system_prompt ).generate_prompt(self.model) # Tokens-check: We trigger the computation of the token count @@ -106,19 +119,32 @@ def summarize( # the query is within the token limit. token_handler.calculate_and_check_available_tokens( final_prompt.format(**llm_input_values), - model_config.context_window_size, - model_config.parameters.max_tokens_for_response, + self.model_config.context_window_size, + self.model_config.parameters.max_tokens_for_response, + ) + + return final_prompt, llm_input_values, rag_chunks, truncated + + def create_response( + self, + query: str, + vector_index: Optional[VectorStoreIndex] = None, + history: Optional[list[str]] = None, + ) -> SummarizerResponse: + """Create a response for the given query based on the provided conversation context.""" + final_prompt, llm_input_values, rag_chunks, truncated = self._prepare_prompt( + query, vector_index, history ) chat_engine = LLMChain( - llm=bare_llm, + llm=self.bare_llm, prompt=final_prompt, - verbose=verbose, + verbose=self.verbose, ) with TokenMetricUpdater( - llm=bare_llm, - provider=provider_config.type, + llm=self.bare_llm, + provider=self.provider_config.type, model=self.model, ) as token_counter: summary = chat_engine.invoke( @@ -132,13 +158,37 @@ def summarize( # Recently watsonx/granite-13b started adding stop token to response. response = response.replace("<|endoftext|>", "") - if len(rag_context) == 0: - logger.debug("Using llm to answer the query without reference content") - logger.debug("%s Summary response: %s", conversation_id, response) - return SummarizerResponse(response, rag_chunks, truncated) - @property - def system_prompt(self) -> str: - """Return actually used system prompt.""" - return self._system_prompt + async def generate_response( + self, + query: str, + vector_index: Optional[VectorStoreIndex] = None, + history: Optional[list[str]] = None, + ) -> AsyncGenerator[str, SummarizerResponse]: + """Generate a response for the given query based on the provided conversation context.""" + final_prompt, llm_input_values, rag_chunks, truncated = self._prepare_prompt( + query, vector_index, history + ) + + with TokenMetricUpdater( + llm=self.bare_llm, + provider=self.provider_config.type, + model=self.model, + ) as token_counter: + async for chunk in self.bare_llm.astream( + final_prompt.format_prompt(**llm_input_values).to_messages(), + config={"callbacks": [token_counter]}, + ): + # TODO: it is bad to have provider specific code here + # the reason we have provider classes is to hide specific + # implementation details there. But it requires expanding + # the current providers interface, eg. 
to stream messages + + # openai returns an `AIMessageChunk` while Watsonx plain string + chunk_content = chunk.content if hasattr(chunk, "content") else chunk + if "<|endoftext|>" in chunk_content: + chunk_content = chunk_content.replace("<|endoftext|>", "") + yield chunk_content + + yield SummarizerResponse("", rag_chunks, truncated) # type: ignore[misc] diff --git a/tests/integration/test_ols.py b/tests/integration/test_ols.py index 390ba41c..13c4872c 100644 --- a/tests/integration/test_ols.py +++ b/tests/integration/test_ols.py @@ -6,7 +6,6 @@ import pytest import requests from fastapi.testclient import TestClient -from langchain.schema import AIMessage, HumanMessage from ols import config, constants from ols.app.models.config import ( @@ -34,9 +33,10 @@ def _setup(): pytest.client = TestClient(app) -def test_post_question_on_unexpected_payload(_setup): - """Check the REST API /v1/query with POST HTTP method when unexpected payload is posted.""" - response = pytest.client.post("/v1/query", json="this is really not proper payload") +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_question_on_unexpected_payload(_setup, endpoint): + """Check the REST API /v1/query when unexpected payload is posted.""" + response = pytest.client.post(endpoint, json="this is really not proper payload") assert response.status_code == requests.codes.unprocessable # try to deserialize payload @@ -58,10 +58,11 @@ def test_post_question_on_unexpected_payload(_setup): } -def test_post_question_without_payload(_setup): - """Check the REST API /v1/query with POST HTTP method when no payload is posted.""" +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_question_without_payload(_setup, endpoint): + """Check the REST API query endpoints when no payload is posted.""" # perform POST request without any payload - response = pytest.client.post("/v1/query") + response = pytest.client.post(endpoint) assert response.status_code == requests.codes.unprocessable # check the response payload @@ -72,43 +73,57 @@ def test_post_question_without_payload(_setup): assert "Field required" in detail["msg"] -def test_post_question_on_invalid_question(_setup): - """Check the REST API /v1/query with POST HTTP method for invalid question.""" +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_question_on_invalid_question(_setup, endpoint): + """Check the REST API /v1/query for invalid question.""" # let's pretend the question is invalid without even asking LLM with patch("ols.app.endpoints.ols.validate_question", return_value=False): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={"conversation_id": conversation_id, "query": "test query"}, ) assert response.status_code == requests.codes.ok - expected_json = { - "conversation_id": conversation_id, - "response": prompts.INVALID_QUERY_RESP, - "referenced_documents": [], - "truncated": False, - } - assert response.json() == expected_json + if response.headers["content-type"] == "application/json": + # non-streaming responses return JSON + expected_response = { + "conversation_id": conversation_id, + "response": prompts.INVALID_QUERY_RESP, + "referenced_documents": [], + "truncated": False, + } + actual_response = response.json() + else: + # streaming_query returns bytes + expected_response = prompts.INVALID_QUERY_RESP + actual_response = response.text + + assert actual_response == expected_response -def 
test_post_question_on_generic_response_type_summarize_error(_setup): - """Check the REST API /v1/query with POST HTTP method when generic response type is returned.""" +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_question_on_generic_response_type_summarize_error(_setup, endpoint): + """Check the REST API query endpoints when generic response type is returned.""" # let's pretend the question is valid and generic one answer = constants.SUBJECT_ALLOWED with ( patch( - "ols.app.endpoints.ols.QuestionValidator.validate_question", + "ols.src.query_helpers.question_validator.QuestionValidator.validate_question", return_value=answer, ), patch( - "ols.app.endpoints.ols.DocsSummarizer.summarize", + "ols.src.query_helpers.docs_summarizer.DocsSummarizer.create_response", + side_effect=Exception("summarizer error"), + ), + patch( + "ols.src.query_helpers.docs_summarizer.DocsSummarizer.generate_response", side_effect=Exception("summarizer error"), ), ): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={"conversation_id": conversation_id, "query": "test query"}, ) assert response.status_code == DEFAULT_STATUS_CODE @@ -122,12 +137,13 @@ def test_post_question_on_generic_response_type_summarize_error(_setup): assert response.json() == expected_json +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_post_question_that_is_not_validated(_setup): - """Check the REST API /v1/query with POST HTTP method for question that is not validated.""" +def test_post_question_that_is_not_validated(_setup, endpoint): + """Check the REST API query endpoints for question that is not validated.""" # let's pretend the question can not be validated with patch( "ols.app.endpoints.ols.QuestionValidator.validate_question", @@ -135,7 +151,7 @@ def test_post_question_that_is_not_validated(_setup): ): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={"conversation_id": conversation_id, "query": "test query"}, ) @@ -150,11 +166,12 @@ def test_post_question_that_is_not_validated(_setup): assert response.json() == expected_details -def test_post_question_with_provider_but_not_model(_setup): +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_question_with_provider_but_not_model(_setup, endpoint): """Check how missing model is detected in request.""" conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -170,11 +187,12 @@ def test_post_question_with_provider_but_not_model(_setup): ) -def test_post_question_with_model_but_not_provider(_setup): +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_question_with_model_but_not_provider(_setup, endpoint): """Check how missing provider is detected in request.""" conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -190,12 +208,13 @@ def test_post_question_with_model_but_not_provider(_setup): ) -def test_unknown_provider_in_post(_setup): - """Check the REST API /v1/query with POST method when unknown provider is requested.""" +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def 
test_unknown_provider_in_post(_setup, endpoint): + """Check the REST API query endpoints with POST method when unknown provider is requested.""" # empty config - no providers config.llm_config.providers = {} response = pytest.client.post( - "/v1/query", + endpoint, json={ "query": "hello?", "provider": "some-provider", @@ -215,19 +234,20 @@ def test_unknown_provider_in_post(_setup): assert response.json() == expected_json +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_unsupported_model_in_post(_setup): - """Check the REST API /v1/query with POST method when unsupported model is requested.""" +def test_unsupported_model_in_post(_setup, endpoint): + """Check the REST API query endpoints with POST method when unsupported model is requested.""" test_provider = "test-provider" provider_config = ProviderConfig() provider_config.models = {} # no models configured config.llm_config.providers = {test_provider: provider_config} response = pytest.client.post( - "/v1/query", + endpoint, json={ "query": "hello?", "provider": test_provider, @@ -246,8 +266,9 @@ def test_unsupported_model_in_post(_setup): assert response.json() == expected_json -def test_post_question_improper_conversation_id(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with improper conversation ID.""" +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_question_improper_conversation_id(_setup, endpoint) -> None: + """Check the REST API query endpoints with improper conversation ID.""" assert config.dev_config is not None config.dev_config.disable_auth = True answer = constants.SUBJECT_ALLOWED @@ -257,7 +278,7 @@ def test_post_question_improper_conversation_id(_setup) -> None: conversation_id = "not-correct-uuid" response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -274,41 +295,39 @@ def test_post_question_improper_conversation_id(_setup) -> None: assert response.json() == expected_details -def test_post_question_on_noyaml_response_type(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method when call is success.""" - answer = constants.SUBJECT_ALLOWED - with patch( - "ols.app.endpoints.ols.QuestionValidator.validate_question", return_value=answer +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_question_on_noyaml_response_type(_setup, endpoint) -> None: + """Check the REST API query endpoints when call is success.""" + ml = mock_langchain_interface("test response") + with ( + patch( + "ols.src.query_helpers.docs_summarizer.LLMChain", + new=mock_llm_chain(None), + ), + patch( + "ols.src.query_helpers.query_helper.load_llm", + new=mock_llm_loader(ml()), + ), ): - ml = mock_langchain_interface("test response") - with ( - patch( - "ols.src.query_helpers.docs_summarizer.LLMChain", - new=mock_llm_chain(None), - ), - patch( - "ols.src.query_helpers.query_helper.load_llm", - new=mock_llm_loader(ml()), - ), - ): - conversation_id = suid.get_suid() - response = pytest.client.post( - "/v1/query", - json={ - "conversation_id": conversation_id, - "query": "test query", - }, - ) - print(response) - assert response.status_code == requests.codes.ok + conversation_id = suid.get_suid() + response = pytest.client.post( + endpoint, + json={ + "conversation_id": conversation_id, + "query": "test query", + }, + 
) + print(response) + assert response.status_code == requests.codes.ok +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.KEYWORD, ) @patch("ols.app.endpoints.ols.QuestionValidator.validate_question") -def test_post_question_with_keyword(mock_llm_validation, _setup) -> None: +def test_post_question_with_keyword(mock_llm_validation, _setup, endpoint) -> None: """Check the REST API /v1/query with keyword validation.""" query = "What is Openshift ?" @@ -325,17 +344,26 @@ def test_post_question_with_keyword(mock_llm_validation, _setup) -> None: ): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={"conversation_id": conversation_id, "query": query}, ) assert response.status_code == requests.codes.ok + + if response.headers["content-type"] == "application/json": + # non-streaming responses return JSON + actual_response = response.json()["response"] + else: + # streaming_query returns bytes + actual_response = response.text + # Currently mock invoke passes same query as response text. - assert query in response.json()["response"] + assert query in actual_response assert mock_llm_validation.call_count == 0 -def test_post_query_with_query_filters_response_type(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with query filters.""" +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_query_with_query_filters_response_type(_setup, endpoint) -> None: + """Check the REST API query endpoints with query filters.""" answer = constants.SUBJECT_ALLOWED query_filters = [ @@ -350,7 +378,8 @@ def test_post_query_with_query_filters_response_type(_setup) -> None: config.ols_config.query_filters = query_filters with patch( - "ols.app.endpoints.ols.QuestionValidator.validate_question", return_value=answer + "ols.src.query_helpers.question_validator.QuestionValidator.validate_question", + return_value=answer, ): ml = mock_langchain_interface("test response") with ( @@ -365,88 +394,91 @@ def test_post_query_with_query_filters_response_type(_setup) -> None: ): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query with 9.25.33.67 will be replaced with redacted_ip", }, ) - print(response.json()) + assert response.status_code == requests.codes.ok + + if response.headers["content-type"] == "application/json": + # non-streaming responses return JSON + actual_response = response.json()["response"] + else: + # streaming_query returns bytes + actual_response = response.text + assert ( "test query with redacted_ip will be replaced with redacted_ip" - in response.json()["response"] + in actual_response ) -def test_post_query_for_conversation_history(_setup) -> None: - """Check the REST API /v1/query with same conversation_id for conversation history.""" - answer = constants.SUBJECT_ALLOWED - with patch( - "ols.app.endpoints.ols.QuestionValidator.validate_question", return_value=answer - ): +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_query_for_conversation_history(_setup, endpoint) -> None: + """Check the REST API query endpoints with same conversation_id for conversation history.""" + # we need to import it here because these modules triggers config + # load too -> causes exception in auth module because of missing config + # values + from 
ols.app.endpoints.ols import retrieve_previous_input # pylint: disable=C0415 + from ols.app.models.models import CacheEntry # pylint: disable=C0415 - ml = mock_langchain_interface("test response") - with ( - patch( - "ols.src.query_helpers.docs_summarizer.LLMChain", - new=mock_llm_chain(None), - ), - patch( - "ols.src.query_helpers.docs_summarizer.LLMChain.invoke", - return_value={"text": "some response"}, - ) as invoke, - patch( - "ols.src.query_helpers.query_helper.load_llm", - new=mock_llm_loader(ml()), - ), - patch( - "ols.app.metrics.token_counter.TokenMetricUpdater.__enter__", - ) as token_counter, - ): - conversation_id = suid.get_suid() - response = pytest.client.post( - "/v1/query", - json={ - "conversation_id": conversation_id, - "query": "Query1", - }, - ) - assert response.status_code == requests.codes.ok - invoke.assert_called_once_with( - input={ - "query": "Query1", - }, - config={"callbacks": [token_counter.return_value]}, - ) - invoke.reset_mock() + actual_returned_history = None - response = pytest.client.post( - "/v1/query", - json={ - "conversation_id": conversation_id, - "query": "Query2", - }, - ) - chat_history_expected = [ - HumanMessage(content="Query1"), - AIMessage(content=response.json()["response"]), - ] - invoke.assert_called_once_with( - input={ - "query": "Query2", - "chat_history": chat_history_expected, - }, - config={"callbacks": [token_counter.return_value]}, - ) + def capture_return_value(*args, **kwargs): + nonlocal actual_returned_history + actual_returned_history = retrieve_previous_input(*args, **kwargs) + return actual_returned_history + ml = mock_langchain_interface("test response") + with ( + patch( + "ols.src.query_helpers.docs_summarizer.LLMChain", + new=mock_llm_chain(None), + ), + patch( + "ols.src.query_helpers.query_helper.load_llm", + new=mock_llm_loader(ml()), + ), + patch( + "ols.app.endpoints.ols.retrieve_previous_input", + side_effect=capture_return_value, + ), + ): + conversation_id = suid.get_suid() + response = pytest.client.post( + endpoint, + json={ + "conversation_id": conversation_id, + "query": "Query1", + }, + ) + assert response.status_code == requests.codes.ok + assert actual_returned_history == [] # pylint: disable=C1803 + + response = pytest.client.post( + endpoint, + json={ + "conversation_id": conversation_id, + "query": "Query2", + }, + ) + assert response.status_code == requests.codes.ok + chat_history_expected = [ + CacheEntry(query="Query1", response="Query1", attachments=[]) + ] + assert actual_returned_history == chat_history_expected + +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_post_question_without_attachments(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method without attachments.""" +def test_post_question_without_attachments(_setup, endpoint) -> None: + """Check the REST API query endpoints without attachments.""" answer = constants.SUBJECT_ALLOWED query_passed = None @@ -473,7 +505,7 @@ def validate_question(_conversation_id, query): ): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -483,13 +515,14 @@ def validate_question(_conversation_id, query): assert query_passed == "test query" +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @patch( 
"ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) @pytest.mark.attachment -def test_post_question_with_empty_list_of_attachments(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with empty list of attachments.""" +def test_post_question_with_empty_list_of_attachments(_setup, endpoint) -> None: + """Check the REST API query endpoints with empty list of attachments.""" answer = constants.SUBJECT_ALLOWED query_passed = None @@ -516,7 +549,7 @@ def validate_question(_conversation_id, query): ): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -527,13 +560,14 @@ def validate_question(_conversation_id, query): assert query_passed == "test query" +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @pytest.mark.attachment @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_post_question_with_one_plaintext_attachment(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with one attachment.""" +def test_post_question_with_one_plaintext_attachment(_setup, endpoint) -> None: + """Check the REST API query endpoints with one attachment.""" answer = constants.SUBJECT_ALLOWED query_passed = None @@ -560,7 +594,7 @@ def validate_question(_conversation_id, query): ): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -584,13 +618,14 @@ def validate_question(_conversation_id, query): assert query_passed == expected +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @pytest.mark.attachment @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_post_question_with_one_yaml_attachment(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with YAML attachment.""" +def test_post_question_with_one_yaml_attachment(_setup, endpoint) -> None: + """Check the REST API query endpoints with YAML attachment.""" answer = constants.SUBJECT_ALLOWED query_passed = None @@ -622,7 +657,7 @@ def validate_question(_conversation_id, query): name: private-reg """ response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -650,13 +685,14 @@ def validate_question(_conversation_id, query): assert query_passed == expected +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @pytest.mark.attachment @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_post_question_with_two_yaml_attachments(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with two YAML attachments.""" +def test_post_question_with_two_yaml_attachments(_setup, endpoint) -> None: + """Check the REST API query endpoints with two YAML attachments.""" answer = constants.SUBJECT_ALLOWED query_passed = None @@ -693,7 +729,7 @@ def validate_question(_conversation_id, query): name: foobar-deployment """ response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -736,13 +772,14 @@ def validate_question(_conversation_id, query): assert query_passed == expected +@pytest.mark.parametrize("endpoint", ("/v1/query", 
"/v1/streaming_query")) @pytest.mark.attachment @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_post_question_with_one_yaml_without_kind_attachment(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with one YAML without kind attachment.""" +def test_post_question_with_one_yaml_without_kind_attachment(_setup, endpoint) -> None: + """Check the REST API query endpoints with one YAML without kind attachment.""" answer = constants.SUBJECT_ALLOWED query_passed = None @@ -773,7 +810,7 @@ def validate_question(_conversation_id, query): name: private-reg """ response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -800,13 +837,14 @@ def validate_question(_conversation_id, query): assert query_passed == expected +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @pytest.mark.attachment @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_post_question_with_one_yaml_without_name_attachment(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with one YAML without name attachment.""" +def test_post_question_with_one_yaml_without_name_attachment(_setup, endpoint) -> None: + """Check the REST API query endpoints with one YAML without name attachment.""" answer = constants.SUBJECT_ALLOWED query_passed = None @@ -838,7 +876,7 @@ def validate_question(_conversation_id, query): foo: bar """ response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -866,13 +904,14 @@ def validate_question(_conversation_id, query): assert query_passed == expected +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @pytest.mark.attachment @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", constants.QueryValidationMethod.LLM, ) -def test_post_question_with_one_invalid_yaml_attachment(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with one invalid YAML attachment.""" +def test_post_question_with_one_invalid_yaml_attachment(_setup, endpoint) -> None: + """Check the REST API query endpoints with one invalid YAML attachment.""" answer = constants.SUBJECT_ALLOWED query_passed = None @@ -904,7 +943,7 @@ def validate_question(_conversation_id, query): name: private-reg """ response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -932,9 +971,10 @@ def validate_question(_conversation_id, query): assert query_passed == expected +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) @pytest.mark.attachment -def test_post_question_with_large_attachment(_setup) -> None: - """Check the REST API /v1/query with POST HTTP method with large attachment.""" +def test_post_question_with_large_attachment(_setup, endpoint) -> None: + """Check the REST API query endpoints with large attachment.""" answer = constants.SUBJECT_ALLOWED def validate_question(_conversation_id, _query): @@ -969,7 +1009,7 @@ def validate_question(_conversation_id, _query): conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={ "conversation_id": conversation_id, "query": "test query", @@ -982,24 +1022,37 @@ def validate_question(_conversation_id, _query): ], }, ) - # error should be returned because of very large input - 
assert response.status_code == requests.codes.request_entity_too_large - - -def test_post_too_long_query(_setup): - """Check the REST API /v1/query with POST HTTP method for query that is too long.""" + if response.headers["content-type"] == "application/json": + # non-streaming responses return JSON + assert response.status_code == requests.codes.request_entity_too_large + else: + # streaming_query returns bytes + error_response = response.text + assert "Prompt is too long" in error_response + assert "exceeds LLM available context window limit" in error_response + + +@pytest.mark.parametrize("endpoint", ("/v1/query", "/v1/streaming_query")) +def test_post_too_long_query(_setup, endpoint): + """Check the REST API query endpoints for query that is too long.""" query = "test query" * 1000 conversation_id = suid.get_suid() response = pytest.client.post( - "/v1/query", + endpoint, json={"conversation_id": conversation_id, "query": query}, ) - # error should be returned - assert response.status_code == requests.codes.request_entity_too_large - error_response = response.json()["detail"] - assert error_response["response"] == "Prompt is too long" - assert "exceeds" in error_response["cause"] + if response.headers["content-type"] == "application/json": + # non-streaming responses return JSON + assert response.status_code == requests.codes.request_entity_too_large + error_response = response.json()["detail"] + assert error_response["response"] == "Prompt is too long" + assert "exceeds" in error_response["cause"] + else: + # streaming_query returns bytes + error_response = response.text + assert "Prompt is too long" in error_response + assert "exceeds LLM available context window limit" in error_response def _post_with_system_prompt_override(_setup, caplog, query, system_prompt): diff --git a/tests/mock_classes/mock_llm_loader.py b/tests/mock_classes/mock_llm_loader.py index c32506c9..5b6f8b9e 100644 --- a/tests/mock_classes/mock_llm_loader.py +++ b/tests/mock_classes/mock_llm_loader.py @@ -14,6 +14,11 @@ def __init__(self, llm=None): llm.model = "mock_model" self.llm = llm + async def astream(self, llm_input, **kwargs): + """Return query result.""" + # yield input prompt/user query + yield llm_input[1].content + def mock_llm_loader(llm=None, expected_params=None): """Construct mock for load_llm.""" diff --git a/tests/unit/app/endpoints/test_ols.py b/tests/unit/app/endpoints/test_ols.py index 722f3d83..cea64621 100644 --- a/tests/unit/app/endpoints/test_ols.py +++ b/tests/unit/app/endpoints/test_ols.py @@ -618,7 +618,7 @@ def test_attachments_redact_on_redact_error(): constants.QueryValidationMethod.LLM, ) @patch("ols.src.query_helpers.question_validator.QuestionValidator.validate_question") -@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.summarize") +@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.create_response") @patch("ols.config.conversation_cache.get") def test_conversation_request( mock_conversation_cache_get, @@ -665,6 +665,38 @@ def test_conversation_request( assert len(response.conversation_id) == 0 +@pytest.mark.usefixtures("_load_config") +@patch("ols.src.query_helpers.question_validator.QuestionValidator.validate_question") +@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.create_response") +@patch("ols.config.conversation_cache.get") +def test_conversation_request_dedup_ref_docs( + mock_conversation_cache_get, + mock_summarize, + mock_validate_question, + auth, +): + """Test deduplication of referenced docs.""" + mock_rag_chunk = [ + 
RagChunk("text1", "url-b", "title-b"), + RagChunk("text2", "url-b", "title-b"), # duplicate doc + RagChunk("text3", "url-a", "title-a"), + ] + mock_validate_question.return_value = True + mock_summarize.return_value = SummarizerResponse( + response="some response", + rag_chunks=mock_rag_chunk, + history_truncated=False, + ) + llm_request = LLMRequest(query="some query") + response = ols.conversation_request(llm_request, auth) + + assert len(response.referenced_documents) == 2 + assert response.referenced_documents[0].docs_url == "url-b" + assert response.referenced_documents[0].title == "title-b" + assert response.referenced_documents[1].docs_url == "url-a" + assert response.referenced_documents[1].title == "title-a" + + @pytest.mark.usefixtures("_load_config") @patch( "ols.app.endpoints.ols.config.ols_config.query_validation_method", @@ -718,7 +750,7 @@ def test_question_validation_in_conversation_start(auth): "ols.app.endpoints.ols.validate_question", new=Mock(return_value=constants.SUBJECT_REJECTED), ) -@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.summarize") +@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.create_response") def test_no_question_validation_in_follow_up_conversation(mock_summarize, auth): """Test if question validation is skipped in follow-up conversation.""" # note the `validate_question` is patched to always return as `SUBJECT_REJECTED` @@ -752,7 +784,7 @@ def test_conversation_request_invalid_subject(mock_validate, auth): @pytest.mark.usefixtures("_load_config") -@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.summarize") +@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.create_response") def test_generate_response_valid_subject(mock_summarize): """Test how generate_response function checks validation results.""" # mock the DocsSummarizer @@ -782,7 +814,7 @@ def test_generate_response_valid_subject(mock_summarize): @pytest.mark.usefixtures("_load_config") -@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.summarize") +@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.create_response") def test_generate_response_on_summarizer_error(mock_summarize): """Test how generate_response function checks validation results.""" # mock the DocsSummarizer diff --git a/tests/unit/app/endpoints/test_streaming_ols.py b/tests/unit/app/endpoints/test_streaming_ols.py new file mode 100644 index 00000000..1bb481ec --- /dev/null +++ b/tests/unit/app/endpoints/test_streaming_ols.py @@ -0,0 +1,131 @@ +"""Unit tests for streaming_ols.py.""" + +import json + +import pytest + +from ols import config, constants +from ols.app.endpoints.streaming_ols import ( + build_referenced_docs, + build_yield_item, + generic_llm_error, + invalid_response_generator, + prompt_too_long_error, + stream_end_event, + stream_start_event, +) +from ols.app.models.models import RagChunk +from ols.utils import suid + +conversation_id = suid.get_suid() + + +async def drain_generator(generator) -> str: + """Drain the async generator and return the result.""" + result = "" + async for item in generator: + result += item + return result + + +@pytest.fixture(scope="function") +def _load_config(): + """Load config before unit tests.""" + config.reload_from_yaml_file("tests/config/test_app_endpoints.yaml") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("_load_config") +async def test_invalid_response_generator(): + """Test invalid_response_generator.""" + generator = invalid_response_generator() + + response = await drain_generator(generator) + + 
assert response == constants.INVALID_QUERY_RESP + + +def test_build_yield_item(): + """Test build_yield_item.""" + assert build_yield_item("bla", 0, constants.MEDIA_TYPE_TEXT) == "bla" + assert ( + build_yield_item("bla", 1, constants.MEDIA_TYPE_JSON) + == '{"event": "token", "data": {"id": 1, "token": "bla"}}' + ) + + +def test_prompt_too_long_error(): + """Test prompt_too_long_error.""" + assert ( + prompt_too_long_error("error", constants.MEDIA_TYPE_TEXT) + == "Prompt is too long: error" + ) + + assert ( + prompt_too_long_error("error", constants.MEDIA_TYPE_JSON) + == '{"event": "error", "data": {"response": "Prompt is too long", "cause": "error"}}' + ) + + +def test_generic_llm_error(): + """Test generic_llm_error.""" + assert ( + generic_llm_error("error", constants.MEDIA_TYPE_TEXT) + == "Oops, something went wrong during LLM invocation: error" + ) + + assert ( + generic_llm_error("error", constants.MEDIA_TYPE_JSON) + == '{"event": "error", "data": {"response": "Oops, something went wrong during LLM invocation", "cause": "error"}}' # noqa: E501 + ) + + +def test_stream_start_event(): + """Test stream_start_event.""" + assert stream_start_event(conversation_id) == json.dumps( + { + "event": "start", + "data": { + "conversation_id": conversation_id, + }, + } + ) + + +def test_stream_end_event(): + """Test stream_end_event.""" + ref_docs = [{"doc_title": "title_1", "doc_url": "doc_url_1"}] + truncated = False + + assert ( + stream_end_event(ref_docs, truncated, constants.MEDIA_TYPE_TEXT) + == "\n\n---\n\ntitle_1: doc_url_1" + ) + + assert stream_end_event( + ref_docs, truncated, constants.MEDIA_TYPE_JSON + ) == json.dumps( + { + "event": "end", + "data": { + "referenced_documents": [ + {"doc_title": "title_1", "doc_url": "doc_url_1"} + ], + "truncated": truncated, + }, + } + ) + + +def test_build_referenced_docs(): + """Test build_referenced_docs.""" + rag_chunks = [ + RagChunk("bla", "url_1", "title_1"), + RagChunk("bla", "url_2", "title_2"), + RagChunk("bla", "url_1", "title_1"), # duplicate + ] + + assert build_referenced_docs(rag_chunks) == [ + {"doc_title": "title_1", "doc_url": "url_1"}, + {"doc_title": "title_2", "doc_url": "url_2"}, + ] diff --git a/tests/unit/app/models/test_models.py b/tests/unit/app/models/test_models.py index 64520f7e..79ba77ac 100644 --- a/tests/unit/app/models/test_models.py +++ b/tests/unit/app/models/test_models.py @@ -16,6 +16,7 @@ ReferencedDocument, StatusResponse, ) +from ols.constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT from ols.utils import suid @@ -33,6 +34,8 @@ def test_llm_request_required_inputs(): assert llm_request.conversation_id is None assert llm_request.provider is None assert llm_request.model is None + assert llm_request.attachments is None + assert llm_request.media_type == "text/plain" @staticmethod def test_llm_request_optional_inputs(): @@ -93,6 +96,22 @@ def test_llm_response(): assert llm_response.referenced_documents == referenced_documents assert not llm_response.truncated + @staticmethod + def test_media_type(): + """Test the media_type field of the LLMRequest model.""" + query = "irrelevant" + + media_type = MEDIA_TYPE_TEXT + llm_request = LLMRequest(query=query, media_type=media_type) + assert llm_request.media_type == media_type + + media_type = MEDIA_TYPE_JSON + llm_request = LLMRequest(query=query, media_type=media_type) + assert llm_request.media_type == media_type + + with pytest.raises(ValidationError, match="Invalid media type: 'unknown'"): + LLMRequest(query=query, media_type="unknown") + class TestStatusResponse: 
"""Unit tests for the StatusResponse model.""" diff --git a/tests/unit/query_helpers/test_docs_summarizer.py b/tests/unit/query_helpers/test_docs_summarizer.py index 2ed8c6f7..5fc1a222 100644 --- a/tests/unit/query_helpers/test_docs_summarizer.py +++ b/tests/unit/query_helpers/test_docs_summarizer.py @@ -59,7 +59,7 @@ def test_summarize_empty_history(): question = "What's the ultimate question with answer 42?" rag_index = MockLlamaIndex() history = [] # empty history - summary = summarizer.summarize(conversation_id, question, rag_index, history) + summary = summarizer.create_response(question, rag_index, history) check_summary_result(summary, question) @@ -72,7 +72,7 @@ def test_summarize_no_history(): question = "What's the ultimate question with answer 42?" rag_index = MockLlamaIndex() # no history is passed into summarize() method - summary = summarizer.summarize(conversation_id, question, rag_index) + summary = summarizer.create_response(question, rag_index) check_summary_result(summary, question) @@ -91,7 +91,7 @@ def test_summarize_history_provided(): "ols.src.query_helpers.docs_summarizer.TokenHandler.limit_conversation_history", return_value=([], False), ) as token_handler: - summary1 = summarizer.summarize(conversation_id, question, rag_index, history) + summary1 = summarizer.create_response(question, rag_index, history) token_handler.assert_called_once_with(history, ANY, ANY) check_summary_result(summary1, question) @@ -100,7 +100,7 @@ def test_summarize_history_provided(): "ols.src.query_helpers.docs_summarizer.TokenHandler.limit_conversation_history", return_value=([], False), ) as token_handler: - summary2 = summarizer.summarize(conversation_id, question, rag_index) + summary2 = summarizer.create_response(question, rag_index) token_handler.assert_called_once_with([], ANY, ANY) check_summary_result(summary2, question) @@ -115,7 +115,7 @@ def test_summarize_truncation(): # too long history history = ["human: What is Kubernetes?"] * 10000 - summary = summarizer.summarize(conversation_id, question, rag_index, history) + summary = summarizer.create_response(question, rag_index, history) # truncation should be done assert summary.history_truncated @@ -128,7 +128,7 @@ def test_summarize_no_reference_content(): llm_loader=mock_llm_loader(mock_langchain_interface("test response")()) ) question = "What's the ultimate question with answer 42?" - summary = summarizer.summarize(conversation_id, question) + summary = summarizer.create_response(question) assert question in summary.response assert summary.rag_chunks == [] assert not summary.history_truncated @@ -154,3 +154,21 @@ def test_summarize_reranker(caplog): # Check captured log text to see if reranker was called. assert "reranker.rerank() is called with 1 result(s)." in caplog.text + + +@pytest.mark.asyncio +@patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None)) +async def test_response_generator(): + """Test response generator method.""" + summarizer = DocsSummarizer( + llm_loader=mock_llm_loader(mock_langchain_interface("test response")()) + ) + question = "What's the ultimate question with answer 42?" 
+    summary_gen = summarizer.generate_response(question)
+    generated_content = ""
+
+    async for item in summary_gen:
+        if isinstance(item, str):
+            generated_content += item
+
+    assert generated_content == question

From b5e6c3997849e516e2ed2cd7c2566aa65cbefc72 Mon Sep 17 00:00:00 2001
From: Ondrej Metelka
Date: Thu, 9 Jan 2025 17:35:22 +0100
Subject: [PATCH 2/4] Add e2e tests for streaming endpoint

---
 ols/app/endpoints/streaming_ols.py         |   4 +
 tests/e2e/test_streaming_query_endpoint.py | 567 +++++++++++++++++++++
 2 files changed, 571 insertions(+)
 create mode 100644 tests/e2e/test_streaming_query_endpoint.py

diff --git a/ols/app/endpoints/streaming_ols.py b/ols/app/endpoints/streaming_ols.py
index 58779850..444d3f3a 100644
--- a/ols/app/endpoints/streaming_ols.py
+++ b/ols/app/endpoints/streaming_ols.py
@@ -332,8 +332,12 @@ async def response_processing_wrapper(
             idx += 1
     except PromptTooLongError as summarizer_error:
         yield prompt_too_long_error(summarizer_error, media_type)
+        return  # stop execution after error
+
     except Exception as summarizer_error:
         yield generic_llm_error(summarizer_error, media_type)
+        return  # stop execution after error
+
     timestamps["generate response"] = time.time()
 
     store_data(
diff --git a/tests/e2e/test_streaming_query_endpoint.py b/tests/e2e/test_streaming_query_endpoint.py
new file mode 100644
index 00000000..358e849d
--- /dev/null
+++ b/tests/e2e/test_streaming_query_endpoint.py
@@ -0,0 +1,567 @@
+"""End to end tests for the REST API streaming query endpoint."""
+
+import json
+import re
+
+import pytest
+import requests
+
+from ols import constants
+from ols.utils import suid
+from tests.e2e.utils import cluster as cluster_utils
+from tests.e2e.utils import metrics as metrics_utils
+from tests.e2e.utils import response as response_utils
+from tests.e2e.utils.decorators import retry
+
+from . 
import test_api + +endpoint = "/v1/streaming_query" + + +def parse_streaming_response_to_events(response: str) -> list[dict]: + """Parse streaming response to events.""" + return json.loads(f'[{response.replace("}{", "},{")}]') + + +def construct_response_from_streamed_events(events: dict) -> str: + """Construct response from streamed events.""" + response = "" + for event in events: + if event["event"] == "token": + response += event["data"]["token"] + return response + + +def test_invalid_question(): + """Check the endpoint POST method for invalid question.""" + with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint): + cid = suid.get_suid() + + response = pytest.client.post( + endpoint, + json={ + "conversation_id": cid, + "query": "how to make burger?", + "media_type": constants.MEDIA_TYPE_TEXT, + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + + assert response.status_code == requests.codes.ok + response_utils.check_content_type(response, constants.MEDIA_TYPE_TEXT) + + assert re.search( + r"(sorry|questions|assist)", + response.text, + re.IGNORECASE, + ) + + +def test_invalid_question_without_conversation_id(): + """Check the endpoint POST method for generating new conversation_id.""" + with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint): + response = pytest.client.post( + endpoint, + json={ + "query": "how to make burger?", + "media_type": constants.MEDIA_TYPE_JSON, + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.ok + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + events = parse_streaming_response_to_events(response.text) + + # new conversation ID should be generated + assert events[0]["event"] == "start" + assert events[0]["data"] + assert suid.check_suid(events[0]["data"]["conversation_id"]) + + +def test_query_call_without_payload(): + """Check the endpoint with POST HTTP method when no payload is provided.""" + with metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.unprocessable_entity, + ): + response = pytest.client.post( + endpoint, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.unprocessable_entity + + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + # the actual response might differ when new Pydantic version + # will be used so let's do just primitive check + assert "missing" in response.text + + +def test_query_call_with_improper_payload(): + """Check the endpoint with POST HTTP method when improper payload is provided.""" + with metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.unprocessable_entity, + ): + response = pytest.client.post( + endpoint, + json={"parameter": "this-is-unknown-parameter"}, + timeout=test_api.NON_LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.unprocessable_entity + + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + # the actual response might differ when new Pydantic version will be used + # so let's do just primitive check + assert "missing" in response.text + + +def test_valid_question_improper_conversation_id() -> None: + """Check the endpoint with POST HTTP method for improper conversation ID.""" + with metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.internal_server_error, + ): + response = pytest.client.post( + endpoint, + 
json={"conversation_id": "not-uuid", "query": "what is kubernetes?"}, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.internal_server_error + + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + json_response = response.json() + expected_response = { + "detail": { + "response": "Error retrieving conversation history", + "cause": "Invalid conversation ID not-uuid", + } + } + assert json_response == expected_response + + +def test_too_long_question() -> None: + """Check the endpoint with too long question.""" + # let's make the query really large, larger that context window size + query = "what is kubernetes?" * 10000 + + with metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.ok, + ): + cid = suid.get_suid() + response = pytest.client.post( + endpoint, + json={ + "conversation_id": cid, + "query": query, + "media_type": constants.MEDIA_TYPE_JSON, + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.ok + + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + + events = parse_streaming_response_to_events(response.text) + + assert len(events) == 2 + assert events[1]["event"] == "error" + assert events[1]["data"]["response"] == "Prompt is too long" + + +@pytest.mark.smoketest +@pytest.mark.rag +def test_valid_question() -> None: + """Check the endpoint with POST HTTP method for valid question and no yaml.""" + with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint): + cid = suid.get_suid() + response = pytest.client.post( + endpoint, + json={"conversation_id": cid, "query": "what is kubernetes?"}, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.ok + + response_utils.check_content_type(response, constants.MEDIA_TYPE_TEXT) + + assert "Kubernetes is" in response.text + assert re.search( + r"orchestration (tool|system|platform|engine)", + response.text, + re.IGNORECASE, + ) + + +@pytest.mark.rag +def test_ocp_docs_version_same_as_cluster_version() -> None: + """Check that the version of OCP docs matches the cluster we're on.""" + with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint): + cid = suid.get_suid() + response = pytest.client.post( + endpoint, + json={ + "conversation_id": cid, + "query": "welcome openshift container platform documentation", + "media_type": constants.MEDIA_TYPE_JSON, + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.ok + + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + major, minor = cluster_utils.get_cluster_version() + events = parse_streaming_response_to_events(response.text) + assert events[-1]["event"] == "end" + assert events[-1]["data"]["referenced_documents"] + assert ( + f"{major}.{minor}" + in events[-1]["data"]["referenced_documents"][0]["doc_title"] + ) + + +def test_valid_question_tokens_counter() -> None: + """Check how the tokens counter are updated accordingly.""" + model, provider = metrics_utils.get_enabled_model_and_provider( + pytest.metrics_client + ) + + with ( + metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint), + metrics_utils.TokenCounterChecker(pytest.metrics_client, model, provider), + ): + response = pytest.client.post( + endpoint, + json={"query": "what is kubernetes?"}, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.ok + 
response_utils.check_content_type(response, constants.MEDIA_TYPE_TEXT) + + +def test_invalid_question_tokens_counter() -> None: + """Check how the tokens counter are updated accordingly.""" + model, provider = metrics_utils.get_enabled_model_and_provider( + pytest.metrics_client + ) + + with ( + metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint), + metrics_utils.TokenCounterChecker(pytest.metrics_client, model, provider), + ): + response = pytest.client.post( + endpoint, + json={"query": "how to make burger?"}, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.ok + response_utils.check_content_type(response, constants.MEDIA_TYPE_TEXT) + + +def test_token_counters_for_query_call_without_payload() -> None: + """Check how the tokens counter are updated accordingly.""" + model, provider = metrics_utils.get_enabled_model_and_provider( + pytest.metrics_client + ) + + with ( + metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.unprocessable_entity, + ), + metrics_utils.TokenCounterChecker( + pytest.metrics_client, + model, + provider, + expect_sent_change=False, + expect_received_change=False, + ), + ): + response = pytest.client.post( + endpoint, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.unprocessable_entity + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + + +def test_token_counters_for_query_call_with_improper_payload() -> None: + """Check how the tokens counter are updated accordingly.""" + model, provider = metrics_utils.get_enabled_model_and_provider( + pytest.metrics_client + ) + + with ( + metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.unprocessable_entity, + ), + metrics_utils.TokenCounterChecker( + pytest.metrics_client, + model, + provider, + expect_sent_change=False, + expect_received_change=False, + ), + ): + response = pytest.client.post( + endpoint, + json={"parameter": "this-is-not-proper-question-my-friend"}, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.unprocessable_entity + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + + +@pytest.mark.rag +@retry(max_attempts=3, wait_between_runs=10) +def test_rag_question() -> None: + """Ensure responses include rag references.""" + with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint): + response = pytest.client.post( + endpoint, + json={ + "query": "what is openshift virtualization?", + "media_type": constants.MEDIA_TYPE_JSON, + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.ok + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + + events = parse_streaming_response_to_events(response.text) + + assert events[0]["event"] == "start" + assert events[0]["data"]["conversation_id"] + assert events[-1]["event"] == "end" + ref_docs = events[-1]["data"]["referenced_documents"] + assert ref_docs + assert "virt" in ref_docs[0]["doc_url"] + assert "https://" in ref_docs[0]["doc_url"] + + # ensure no duplicates in docs + docs_urls = [doc["doc_url"] for doc in ref_docs] + assert len(set(docs_urls)) == len(docs_urls) + + +@pytest.mark.cluster +def test_query_filter() -> None: + """Ensure responses does not include filtered words and redacted words are not logged.""" + with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, 
endpoint): + query = "what is foo in bar?" + response = pytest.client.post( + endpoint, + json={"query": query}, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.ok + response_utils.check_content_type(response, constants.MEDIA_TYPE_TEXT) + + # values to be filtered and replaced are defined in: + # tests/config/singleprovider.e2e.template.config.yaml + response_text = response.text.lower() + assert "openshift" in response_text + assert "deploy" in response_text + response_words = response_text.split() + assert "foo" not in response_words + assert "bar" not in response_words + + # Retrieve the pod name + ols_container_name = "lightspeed-service-api" + pod_name = cluster_utils.get_pod_by_prefix()[0] + + # Check if filtered words are redacted in the logs + container_log = cluster_utils.get_container_log(pod_name, ols_container_name) + + # Ensure redacted patterns do not appear in the logs + unwanted_patterns = ["foo ", "what is foo in bar?"] + for line in container_log.splitlines(): + # Only check lines that are not part of a query + if re.search(r'Body: \{"query":', line): + continue + # check that the pattern is indeed not found in logs + for pattern in unwanted_patterns: + assert pattern not in line.lower() + + # Ensure the intended redaction has occurred + assert "what is deployment in openshift?" in container_log + + +@retry(max_attempts=3, wait_between_runs=10) +def test_conversation_history() -> None: + """Ensure conversations include previous query history.""" + with metrics_utils.RestAPICallCounterChecker(pytest.metrics_client, endpoint): + response = pytest.client.post( + endpoint, + json={ + "query": "what is ingress in kubernetes?", + "media_type": constants.MEDIA_TYPE_JSON, + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + scenario_fail_msg = "First call to LLM without conversation history has failed" + assert response.status_code == requests.codes.ok, scenario_fail_msg + response_utils.check_content_type( + response, constants.MEDIA_TYPE_JSON, scenario_fail_msg + ) + + events = parse_streaming_response_to_events(response.text) + response_text = construct_response_from_streamed_events(events).lower() + + assert "ingress" in response_text, scenario_fail_msg + + # get the conversation id so we can reuse it for the follow up question + assert events[0]["event"] == "start" + cid = events[0]["data"]["conversation_id"] + response = pytest.client.post( + endpoint, + json={ + "conversation_id": cid, + "query": "what?", + "media_type": constants.MEDIA_TYPE_JSON, + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + + scenario_fail_msg = "Second call to LLM with conversation history has failed" + assert response.status_code == requests.codes.ok + response_utils.check_content_type( + response, constants.MEDIA_TYPE_JSON, scenario_fail_msg + ) + + events = parse_streaming_response_to_events(response.text) + response_text = construct_response_from_streamed_events(events).lower() + assert "ingress" in response_text, scenario_fail_msg + + +def test_query_with_provider_but_not_model() -> None: + """Check the endpoint with POST HTTP method for provider specified, but no model.""" + with metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.unprocessable_entity, + ): + # just the provider is explicitly specified, but model selection is missing + response = pytest.client.post( + endpoint, + json={ + "conversation_id": "", + "query": "what is kubernetes?", + "provider": "bam", + }, + 
timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.unprocessable_entity + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + + json_response = response.json() + + # error thrown on Pydantic level + assert ( + json_response["detail"][0]["msg"] + == "Value error, LLM model must be specified when the provider is specified." + ) + + +def test_query_with_model_but_not_provider() -> None: + """Check the endpoint with POST HTTP method for model specified, but no provider.""" + with metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.unprocessable_entity, + ): + # just model is explicitly specified, but provider selection is missing + response = pytest.client.post( + endpoint, + json={ + "conversation_id": "", + "query": "what is kubernetes?", + "model": "ibm/granite-13b-chat-v2", + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.unprocessable_entity + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + + json_response = response.json() + + assert ( + json_response["detail"][0]["msg"] + == "Value error, LLM provider must be specified when the model is specified." + ) + + +def test_query_with_unknown_provider() -> None: + """Check the endpoint with POST HTTP method for unknown provider specified.""" + # retrieve currently selected model + model, _ = metrics_utils.get_enabled_model_and_provider(pytest.metrics_client) + + with metrics_utils.RestAPICallCounterChecker( + pytest.metrics_client, + endpoint, + status_code=requests.codes.unprocessable_entity, + ): + # provider is unknown + response = pytest.client.post( + endpoint, + json={ + "conversation_id": "", + "query": "what is kubernetes?", + "provider": "foo", + "model": model, + }, + timeout=test_api.LLM_REST_API_TIMEOUT, + ) + assert response.status_code == requests.codes.unprocessable_entity + response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON) + + json_response = response.json() + + # explicit response and cause check + assert ( + "detail" in json_response + ), "Improper response format: 'detail' node is missing" + assert "Unable to process this request" in json_response["detail"]["response"] + assert ( + "Provider 'foo' is not a valid provider." 
+            in json_response["detail"]["cause"]
+        )
+
+
+def test_query_with_unknown_model() -> None:
+    """Check the endpoint with POST HTTP method for unknown model specified."""
+    # retrieve currently selected provider
+    _, provider = metrics_utils.get_enabled_model_and_provider(pytest.metrics_client)
+
+    with metrics_utils.RestAPICallCounterChecker(
+        pytest.metrics_client,
+        endpoint,
+        status_code=requests.codes.unprocessable_entity,
+    ):
+        # model is unknown
+        response = pytest.client.post(
+            endpoint,
+            json={
+                "conversation_id": "",
+                "query": "what is kubernetes?",
+                "provider": provider,
+                "model": "bar",
+            },
+            timeout=test_api.LLM_REST_API_TIMEOUT,
+        )
+        assert response.status_code == requests.codes.unprocessable_entity
+        response_utils.check_content_type(response, constants.MEDIA_TYPE_JSON)
+
+        json_response = response.json()
+
+        # explicit response and cause check
+        assert (
+            "detail" in json_response
+        ), "Improper response format: 'detail' node is missing"
+        assert "Unable to process this request" in json_response["detail"]["response"]
+        assert "Model 'bar' is not a valid model " in json_response["detail"]["cause"]

From 75db7d38c175f3ea4ae99f6a7a81b69a2780a767 Mon Sep 17 00:00:00 2001
From: Ondrej Metelka
Date: Fri, 10 Jan 2025 13:24:56 +0100
Subject: [PATCH 3/4] Fix readme part about streaming_query

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 420128cf..d915c55e 100644
--- a/README.md
+++ b/README.md
@@ -609,7 +609,7 @@ To send a request to the server you can use the following curl command:
 curl -X 'POST' 'http://127.0.0.1:8080/v1/query' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"query": "write a deployment yaml for the mongodb image"}'
 ```
 
-> You can use `/v1/streaming_query` endpoint (with same parameters) to get the streaming response (SSE/HTTP chunking). By default, it streams text, but you can also yield events as JSONs via additionl `"media_type": "text/plain"` parameter in the payload data.
+> You can use the `/v1/streaming_query` endpoint (with the same parameters) to get the streaming response (SSE/HTTP chunking). By default, it streams plain text, but you can also receive the events as JSON objects by adding the `"media_type": "application/json"` parameter to the payload data.
 
 ### Swagger UI
 

From d06de5ed4401f35bb09cf69d499524b8d0658351 Mon Sep 17 00:00:00 2001
From: onmete
Date: Thu, 16 Jan 2025 11:05:40 +0100
Subject: [PATCH 4/4] Fix prompts path and arg after cherry pick

---
 ols/app/endpoints/streaming_ols.py               | 3 ++-
 ols/src/query_helpers/docs_summarizer.py         | 5 ++---
 tests/unit/app/endpoints/test_streaming_ols.py   | 3 ++-
 tests/unit/query_helpers/test_docs_summarizer.py | 4 ++--
 4 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/ols/app/endpoints/streaming_ols.py b/ols/app/endpoints/streaming_ols.py
index 444d3f3a..26ac8f59 100644
--- a/ols/app/endpoints/streaming_ols.py
+++ b/ols/app/endpoints/streaming_ols.py
@@ -32,6 +32,7 @@
     UnauthorizedResponse,
 )
 from ols.constants import MEDIA_TYPE_TEXT
+from ols.customize import prompts
 from ols.src.auth.auth import get_auth_dependency
 from ols.utils import errors_parsing
 from ols.utils.token_handler import PromptTooLongError
@@ -119,7 +120,7 @@ async def invalid_response_generator() -> AsyncGenerator[str, None]:
 
     Yields:
         str: The response indicating invalid query.
""" - yield constants.INVALID_QUERY_RESP + yield prompts.INVALID_QUERY_RESP def stream_start_event(conversation_id: str) -> str: diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py index 044a6dd5..90fd81dc 100644 --- a/ols/src/query_helpers/docs_summarizer.py +++ b/ols/src/query_helpers/docs_summarizer.py @@ -11,9 +11,8 @@ from ols.app.metrics import TokenMetricUpdater from ols.app.models.models import RagChunk, SummarizerResponse from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters -from ols.customize import reranker +from ols.customize import prompts, reranker from ols.src.prompts.prompt_generator import GeneratePrompt -from ols.src.prompts.prompts import QUERY_SYSTEM_INSTRUCTION from ols.src.query_helpers.query_helper import QueryHelper from ols.utils.token_handler import TokenHandler @@ -48,7 +47,7 @@ def _get_system_prompt(self) -> None: if config.ols_config.system_prompt is not None: self.system_prompt = config.ols_config.system_prompt else: - self.system_prompt = QUERY_SYSTEM_INSTRUCTION + self.system_prompt = prompts.QUERY_SYSTEM_INSTRUCTION logger.debug("System prompt: %s", self.system_prompt) def _prepare_prompt( diff --git a/tests/unit/app/endpoints/test_streaming_ols.py b/tests/unit/app/endpoints/test_streaming_ols.py index 1bb481ec..5d6ec8d1 100644 --- a/tests/unit/app/endpoints/test_streaming_ols.py +++ b/tests/unit/app/endpoints/test_streaming_ols.py @@ -15,6 +15,7 @@ stream_start_event, ) from ols.app.models.models import RagChunk +from ols.customize import prompts from ols.utils import suid conversation_id = suid.get_suid() @@ -42,7 +43,7 @@ async def test_invalid_response_generator(): response = await drain_generator(generator) - assert response == constants.INVALID_QUERY_RESP + assert response == prompts.INVALID_QUERY_RESP def test_build_yield_item(): diff --git a/tests/unit/query_helpers/test_docs_summarizer.py b/tests/unit/query_helpers/test_docs_summarizer.py index 5fc1a222..7aeb9b21 100644 --- a/tests/unit/query_helpers/test_docs_summarizer.py +++ b/tests/unit/query_helpers/test_docs_summarizer.py @@ -148,8 +148,8 @@ def test_summarize_reranker(caplog): summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) question = "What's the ultimate question with answer 42?" rag_index = MockLlamaIndex() - # no history is passed into summarize() method - summary = summarizer.summarize(conversation_id, question, rag_index) + # no history is passed into create_response() method + summary = summarizer.create_response(question, rag_index) check_summary_result(summary, question) # Check captured log text to see if reranker was called.