diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..c59971b --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,41 @@ +name: Tests + +on: + pull_request: + branches: + - main + types: + - opened + - synchronize + - reopened + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + # Set up Python environment + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.x" + + # Install Poetry + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + # Install lib and dev dependencies + - name: Install llmstudio-core + working-directory: ./libs/core + run: | + poetry install + POETRY_ENV=$(poetry env info --path) + echo $POETRY_ENV + echo "POETRY_ENV=$POETRY_ENV" >> $GITHUB_ENV + + - name: Run unit tests + run: | + echo ${{ env.POETRY_ENV }} + source ${{ env.POETRY_ENV }}/bin/activate + poetry run pytest libs/core diff --git a/.github/workflows/upload-pypi-dev.yml b/.github/workflows/upload-pypi-dev.yml index b13bf40..85cf886 100644 --- a/.github/workflows/upload-pypi-dev.yml +++ b/.github/workflows/upload-pypi-dev.yml @@ -1,4 +1,4 @@ -name: PyPI prerelease and build/push Docker image. +name: PyPI prerelease any module. on: workflow_dispatch: diff --git a/.gitignore b/.gitignore index 1901586..d80305e 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,7 @@ env3 .env* .env*.local .venv* +*venv* env*/ venv*/ ENV/ @@ -66,6 +67,7 @@ venv.bak/ config.yaml bun.lockb + # Jupyter Notebook .ipynb_checkpoints @@ -76,4 +78,4 @@ bun.lockb llmstudio/llm_engine/logs/execution_logs.jsonl *.db .prettierignore -db \ No newline at end of file +db diff --git a/Makefile b/Makefile index 5a43a3e..6c95760 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,5 @@ format: pre-commit run --all-files + +unit-tests: + pytest libs/core/tests/unit_tests diff --git a/libs/core/llmstudio_core/providers/azure.py b/libs/core/llmstudio_core/providers/azure.py index 2dbd730..f558f9d 100644 --- a/libs/core/llmstudio_core/providers/azure.py +++ b/libs/core/llmstudio_core/providers/azure.py @@ -62,7 +62,25 @@ async def agenerate_client(self, request: ChatRequest) -> Any: return self.generate_client(request=request) def generate_client(self, request: ChatRequest) -> Any: - """Generate an AzureOpenAI client""" + """ + Generates an AzureOpenAI client for processing a chat request. + + This method prepares and configures the arguments required to create a client + request to AzureOpenAI's chat completions API. It determines model-specific + configurations (e.g., whether tools or functions are enabled) and combines + these with the base arguments for the API call. + + Args: + request (ChatRequest): The chat request object containing the model, + parameters, and other necessary details. + + Returns: + Any: The result of the chat completions API call. + + Raises: + ProviderError: If there is an issue with the API connection or an error + returned from the API. + """ self.is_llama = "llama" in request.model.lower() self.is_openai = "gpt" in request.model.lower() @@ -72,7 +90,6 @@ def generate_client(self, request: ChatRequest) -> Any: try: messages = self.prepare_messages(request) - # Prepare the optional tool-related arguments tool_args = {} if not self.is_llama and self.has_tools and self.is_openai: tool_args = { @@ -80,7 +97,6 @@ def generate_client(self, request: ChatRequest) -> Any: "tool_choice": "auto" if request.parameters.get("tools") else None, } - # Prepare the optional function-related arguments function_args = {} if not self.is_llama and self.has_functions and self.is_openai: function_args = { @@ -90,14 +106,12 @@ def generate_client(self, request: ChatRequest) -> Any: else None, } - # Prepare the base arguments base_args = { "model": request.model, "messages": messages, "stream": True, } - # Combine all arguments combined_args = { **base_args, **tool_args, @@ -116,13 +130,13 @@ def prepare_messages(self, request: ChatRequest): if self.is_llama and (self.has_tools or self.has_functions): user_message = self.convert_to_openai_format(request.chat_input) content = "<|begin_of_text|>" - content = self.add_system_message( + content = self.build_llama_system_message( user_message, content, request.parameters.get("tools"), request.parameters.get("functions"), ) - content = self.add_conversation(user_message, content) + content = self.build_llama_conversation(user_message, content) return [{"role": "user", "content": content}] else: return ( @@ -139,6 +153,20 @@ async def aparse_response( yield chunk def parse_response(self, response: AsyncGenerator, **kwargs) -> Any: + """ + Processes a generator response and yields processed chunks. + + If `is_llama` is True and tools or functions are enabled, it processes the response + using `handle_tool_response`. Otherwise, it processes each chunk and yields only those + containing "choices". + + Args: + response (Generator): The response generator to process. + **kwargs: Additional arguments for tool handling. + + Yields: + Any: Processed response chunks. + """ if self.is_llama and (self.has_tools or self.has_functions): for chunk in self.handle_tool_response(response, **kwargs): if chunk: @@ -388,9 +416,25 @@ def convert_to_openai_format(self, message: Union[str, list]) -> list: return [{"role": "user", "content": message}] return message - def add_system_message( + def build_llama_system_message( self, openai_message: list, llama_message: str, tools: list, functions: list ) -> str: + """ + Builds a complete system message for Llama based on OpenAI's message, tools, and functions. + + If a system message is present in the OpenAI message, it is included in the result. + Otherwise, a default system message is used. Additional tool and function instructions + are appended if provided. + + Args: + openai_message (list): List of OpenAI messages. + llama_message (str): The message to prepend to the system message. + tools (list): List of tools to include in the system message. + functions (list): List of functions to include in the system message. + + Returns: + str: The formatted system message combined with Llama message. + """ system_message = "" system_message_found = False for message in openai_message: @@ -407,15 +451,31 @@ def add_system_message( """ if tools: - system_message = system_message + self.add_tool_instructions(tools) + system_message = system_message + self.build_tool_instructions(tools) if functions: - system_message = system_message + self.add_function_instructions(functions) + system_message = system_message + self.build_function_instructions( + functions + ) end_tag = "\n<|eot_id|>" return llama_message + system_message + end_tag - def add_tool_instructions(self, tools: list) -> str: + def build_tool_instructions(self, tools: list) -> str: + """ + Builds a detailed instructional prompt for tools available to the assistant. + + This function generates a message describing the available tools, focusing on tools + of type "function." It explains to the LLM how to use each tool and provides an example of the + correct response format for function calls. + + Args: + tools (list): A list of tool dictionaries, where each dictionary contains tool + details such as type, function name, description, and parameters. + + Returns: + str: A formatted string detailing the tool instructions and usage examples. + """ tool_prompt = """ You have access to the following tools: """ @@ -449,7 +509,21 @@ def add_tool_instructions(self, tools: list) -> str: return tool_prompt - def add_function_instructions(self, functions: list) -> str: + def build_function_instructions(self, functions: list) -> str: + """ + Builds a detailed instructional prompt for available functions. + + This method creates a message describing the functions accessible to the assistant. + It includes the function name, description, and required parameters, along with + specific guidelines for calling functions. + + Args: + functions (list): A list of function dictionaries, each containing details such as + name, description, and parameters. + + Returns: + str: A formatted string with instructions on using the provided functions. + """ function_prompt = """ You have access to the following functions: """ @@ -479,35 +553,60 @@ def add_function_instructions(self, functions: list) -> str: """ return function_prompt - def add_conversation(self, openai_message: list, llama_message: str) -> str: + def build_llama_conversation(self, openai_message: list, llama_message: str) -> str: + """ + Appends the OpenAI message to the Llama message while formatting OpenAI messages. + + This function iterates through a list of OpenAI messages and formats them for inclusion + in a Llama message. It handles user messages that might include nested content (lists of + messages) by safely evaluating the content. System messages are skipped. + + Args: + openai_message (list): A list of dictionaries representing the OpenAI messages. Each + dictionary should have "role" and "content" keys. + llama_message (str): The initial Llama message to which the conversation is appended. + + Returns: + str: The Llama message with the conversation appended. + """ conversation_parts = [] for message in openai_message: if message["role"] == "system": continue elif message["role"] == "user" and isinstance(message["content"], str): try: - # Attempt to safely evaluate the string to a Python object content_as_list = ast.literal_eval(message["content"]) if isinstance(content_as_list, list): - # If the content is a list, process each nested message for nested_message in content_as_list: conversation_parts.append( self.format_message(nested_message) ) else: - # If the content is not a list, append it directly conversation_parts.append(self.format_message(message)) except (ValueError, SyntaxError): - # If evaluation fails or content is not a list/dict string, append the message directly conversation_parts.append(self.format_message(message)) else: - # For all other messages, use the existing formatting logic conversation_parts.append(self.format_message(message)) return llama_message + "".join(conversation_parts) def format_message(self, message: dict) -> str: - """Format a single message for the conversation.""" + """ + Formats a single message dictionary into a structured string for a conversation. + + The formatting depends on the content of the message, such as tool calls, + function calls, or simple user/assistant messages. Each type of message + is formatted with specific headers and tags. + + Args: + message (dict): A dictionary containing message details. Expected keys + include "role", "content", and optionally "tool_calls", + "tool_call_id", or "function_call". + + Returns: + str: A formatted string representing the message. Returns an empty + string if the message cannot be formatted. + """ if "tool_calls" in message: for tool_call in message["tool_calls"]: function_name = tool_call["function"]["name"] diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py index 07e0383..69a673d 100644 --- a/libs/core/llmstudio_core/providers/provider.py +++ b/libs/core/llmstudio_core/providers/provider.py @@ -152,8 +152,38 @@ async def achat( parameters: Optional[dict] = {}, **kwargs, ): - - """Makes a chat connection with the provider's API""" + """ + Asynchronously establishes a chat connection with the provider’s API, handling retries, + request validation, and streaming response options. + + Parameters + ---------- + chat_input : Any + The input data for the chat request, such as a string or dictionary, to be sent to the API. + model : str + The identifier of the model to be used for the chat request. + is_stream : Optional[bool], default=False + Flag to indicate if the response should be streamed. If True, returns an async generator + for streaming content; otherwise, returns the complete response. + retries : Optional[int], default=0 + Number of retry attempts on error. Retries will be attempted for specific HTTP errors like rate limits. + parameters : Optional[dict], default={} + Additional configuration parameters for the request, such as temperature or max tokens. + **kwargs + Additional keyword arguments to customize the request. + + Returns + ------- + Union[AsyncGenerator, Any] + - If `is_stream` is True, returns an async generator yielding response chunks. + - If `is_stream` is False, returns the first complete response chunk. + + Raises + ------ + ProviderError + - Raised if the request validation fails or if all retry attempts are exhausted. + - Also raised for unexpected exceptions during request handling. + """ try: request = self.validate_request( dict( @@ -198,8 +228,38 @@ def chat( parameters: Optional[dict] = {}, **kwargs, ): - - """Makes a chat connection with the provider's API""" + """ + Establishes a chat connection with the provider’s API, handling retries, request validation, + and streaming response options. + + Parameters + ---------- + chat_input : Any + The input data for the chat request, often a string or dictionary, to be sent to the API. + model : str + The model identifier for selecting the model used in the chat request. + is_stream : Optional[bool], default=False + Flag to indicate if the response should be streamed. If True, the function returns a generator + for streaming content. Otherwise, it returns the complete response. + retries : Optional[int], default=0 + Number of retry attempts on error. Retries will be attempted on specific HTTP errors like rate limits. + parameters : Optional[dict], default={} + Additional configuration parameters for the request, such as temperature or max tokens. + **kwargs + Additional keyword arguments that can be passed to customize the request. + + Returns + ------- + Union[Generator, Any] + - If `is_stream` is True, returns a generator that yields chunks of the response. + - If `is_stream` is False, returns the first complete response chunk. + + Raises + ------ + ProviderError + - Raised if the request validation fails or if the request fails after the specified number of retries. + - Also raised on other unexpected exceptions during request handling. + """ try: request = self.validate_request( dict( @@ -238,7 +298,28 @@ def chat( async def ahandle_response( self, request: ChatRequest, response: AsyncGenerator, start_time: float ) -> AsyncGenerator[str, None]: - """Handles the response from an API""" + """ + Asynchronously handles the response from an API, processing response chunks for either + streaming or non-streaming responses. + + Buffers response chunks for non-streaming responses to output one single message. For streaming responses sends incremental chunks. + + Parameters + ---------- + request : ChatRequest + The chat request object, which includes input data, model name, and streaming options. + response : AsyncGenerator + The async generator yielding response chunks from the API. + start_time : float + The timestamp when the response handling started, used for latency calculations. + + Yields + ------ + Union[ChatCompletionChunk, ChatCompletion] + - If `request.is_stream` is True, yields `ChatCompletionChunk` objects with incremental + response chunks for streaming. + - If `request.is_stream` is False, yields a final `ChatCompletion` object after processing all chunks. + """ first_token_time = None previous_token_time = None token_times = [] @@ -294,7 +375,7 @@ async def ahandle_response( chunks = [chunk[0] if isinstance(chunk, tuple) else chunk for chunk in chunks] model = next(chunk["model"] for chunk in chunks if chunk.get("model")) - response, output_string = self.join_chunks(chunks, request) + response, output_string = self.join_chunks(chunks) metrics = self.calculate_metrics( request.chat_input, @@ -346,7 +427,29 @@ async def ahandle_response( def handle_response( self, request: ChatRequest, response: Generator, start_time: float ) -> Generator: - """Handles the response from an API""" + """ + Processes API response chunks to build a structured, complete response, yielding + each chunk if streaming is enabled. + + If streaming, each chunk is yielded as soon as it’s processed. Otherwise, all chunks + are combined and yielded as a single response at the end. + + Parameters + ---------- + request : ChatRequest + The original request details, including model, input, and streaming preference. + response : Generator + A generator yielding partial response chunks from the API. + start_time : float + The start time for measuring response timing. + + Yields + ------ + Union[ChatCompletionChunk, ChatCompletion] + If streaming (`is_stream=True`), yields each `ChatCompletionChunk` as it’s processed. + Otherwise, yields a single `ChatCompletion` with the full response data. + + """ first_token_time = None previous_token_time = None token_times = [] @@ -402,7 +505,7 @@ def handle_response( chunks = [chunk[0] if isinstance(chunk, tuple) else chunk for chunk in chunks] model = next(chunk["model"] for chunk in chunks if chunk.get("model")) - response, output_string = self.join_chunks(chunks, request) + response, output_string = self.join_chunks(chunks) metrics = self.calculate_metrics( request.chat_input, @@ -451,7 +554,29 @@ def handle_response( else: yield ChatCompletion(**response) - def join_chunks(self, chunks, request): + def join_chunks(self, chunks): + """ + Combine multiple response chunks from the model into a single, structured response. + Handles tool calls, function calls, and standard text completion based on the + purpose indicated by the final chunk. + + Parameters + ---------- + chunks : List[Dict] + A list of partial responses (chunks) from the model. + + Returns + ------- + Tuple[ChatCompletion, str] + - `ChatCompletion`: The structured response based on the type of completion + (tool calls, function call, or text). + - `str`: The concatenated content or arguments, depending on the completion type. + + Raises + ------ + Exception + If there is an issue constructing the response, an exception is raised. + """ finish_reason = chunks[-1].get("choices")[0].get("finish_reason") if finish_reason == "tool_calls": @@ -612,7 +737,42 @@ def calculate_metrics( token_times: Tuple[float, ...], token_count: int, ) -> Dict[str, Any]: - """Calculates metrics based on token times and output""" + """ + Calculates performance and cost metrics for a model response based on timing + information, token counts, and model-specific costs. + + Parameters + ---------- + input : Any + The input provided to the model, used to determine input token count. + output : Any + The output generated by the model, used to determine output token count. + model : str + The model identifier, used to retrieve model-specific configuration and costs. + start_time : float + The timestamp marking the start of the model response. + end_time : float + The timestamp marking the end of the model response. + first_token_time : float + The timestamp when the first token was received, used for latency calculations. + token_times : Tuple[float, ...] + A tuple of time intervals between received tokens, used for inter-token latency. + token_count : int + The total number of tokens processed in the response. + + Returns + ------- + Dict[str, Any] + A dictionary containing calculated metrics, including: + - `input_tokens`: Number of tokens in the input. + - `output_tokens`: Number of tokens in the output. + - `total_tokens`: Total token count (input + output). + - `cost_usd`: Total cost of the response in USD. + - `latency_s`: Total time taken for the response, in seconds. + - `time_to_first_token_s`: Time to receive the first token, in seconds. + - `inter_token_latency_s`: Average time between tokens, in seconds. If `token_times` is empty sets it to 0. + - `tokens_per_second`: Processing rate of tokens per second. + """ model_config = self.config.models[model] input_tokens = len(self.tokenizer.encode(self.input_to_string(input))) output_tokens = len(self.tokenizer.encode(self.output_to_string(output))) @@ -628,17 +788,42 @@ def calculate_metrics( "cost_usd": input_cost + output_cost, "latency_s": total_time, "time_to_first_token_s": first_token_time - start_time, - "inter_token_latency_s": sum(token_times) / len(token_times), - "tokens_per_second": token_count / total_time, + "inter_token_latency_s": sum(token_times) / len(token_times) + if token_times + else 0, + "tokens_per_second": token_count / total_time + if token_times + else 1 / total_time, } def calculate_cost( self, token_count: int, token_cost: Union[float, List[Dict[str, Any]]] ) -> float: + """ + Calculates the cost for a given number of tokens based on a fixed cost per token + or a variable rate structure. + + If `token_cost` is a fixed float, the total cost is `token_count * token_cost`. + If `token_cost` is a list, it checks each range and calculates cost based on the applicable range's rate. + + Parameters + ---------- + token_count : int + The total number of tokens for which the cost is being calculated. + token_cost : Union[float, List[Dict[str, Any]]] + Either a fixed cost per token (as a float) or a list of dictionaries defining + variable cost ranges. Each dictionary in the list represents a range with + 'range' (a tuple of minimum and maximum token counts) and 'cost' (cost per token) keys. + + Returns + ------- + float + The calculated cost based on the token count and cost structure. + """ if isinstance(token_cost, list): for cost_range in token_cost: if token_count >= cost_range.range[0] and ( - token_count <= cost_range.range[1] or cost_range.range[1] is None + cost_range.range[1] is None or token_count <= cost_range.range[1] ): return cost_range.cost * token_count else: @@ -646,6 +831,23 @@ def calculate_cost( return 0 def input_to_string(self, input): + """ + Converts an input, which can be a string or a structured list of messages, into a single concatenated string. + + Parameters + ---------- + input : Any + The input data to be converted. This can be: + - A simple string, which is returned as-is. + - A list of message dictionaries, where each dictionary may contain `content`, `role`, + and nested items like `text` or `image_url`. + + Returns + ------- + str + A concatenated string representing the text content of all messages, + including text and URLs from image content if present. + """ if isinstance(input, str): return input else: @@ -667,6 +869,23 @@ def input_to_string(self, input): return "".join(result) def output_to_string(self, output): + """ + Extracts and returns the content or arguments from the output based on + the `finish_reason` of the first choice in `output`. + + Parameters + ---------- + output : Any + The model output object, expected to have a `choices` attribute that should contain a `finish_reason` indicating the type of output + ("stop", "tool_calls", or "function_call") and corresponding content or arguments. + + Returns + ------- + str + - If `finish_reason` is "stop": Returns the message content. + - If `finish_reason` is "tool_calls": Returns the arguments for the first tool call. + - If `finish_reason` is "function_call": Returns the arguments for the function call. + """ if output.choices[0].finish_reason == "stop": return output.choices[0].message.content elif output.choices[0].finish_reason == "tool_calls": diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index de5a22c..0847b86 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "llmstudio-core" -version = "1.0.2" +version = "1.0.3a0" description = "LLMStudio core capabilities for routing llm calls for any vendor. No proxy server required. For that use llmstudio[proxy]" authors = ["Cláudio Lemos "] license = "MIT" diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 23f070e..1d1939a 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock import pytest +from llmstudio_core.providers.azure import AzureProvider from llmstudio_core.providers.provider import ProviderCore @@ -11,14 +12,6 @@ async def aparse_response(self, response, **kwargs): def parse_response(self, response, **kwargs): return response - def chat(self, chat_input, model, **kwargs): - # Mock the response to match expected structure - return MagicMock(choices=[MagicMock(finish_reason="stop")]) - - async def achat(self, chat_input, model, **kwargs): - # Mock the response to match expected structure - return MagicMock(choices=[MagicMock(finish_reason="stop")]) - def output_to_string(self, output): # Handle string inputs if isinstance(output, str): @@ -27,6 +20,24 @@ def output_to_string(self, output): return output.choices[0].message.content return "" + def validate_request(self, request): + # For testing, simply return the request + return request + + async def agenerate_client(self, request): + # For testing, return an async generator + async def async_gen(): + yield {} + + return async_gen() + + def generate_client(self, request): + # For testing, return a generator + def gen(): + yield {} + + return gen() + @staticmethod def _provider_config_name(): return "mock_provider" @@ -42,3 +53,34 @@ def mock_provider(): tokenizer = MagicMock() tokenizer.encode = lambda x: x.split() # Simple tokenizer mock return MockProvider(config=config, tokenizer=tokenizer) + + +class MockAzureProvider(AzureProvider): + async def aparse_response(self, response, **kwargs): + return response + + async def agenerate_client(self, request): + # For testing, return an async generator + async def async_gen(): + yield {} + + return async_gen() + + @staticmethod + def _provider_config_name(): + return "mock_azure_provider" + + +@pytest.fixture +def mock_azure_provider(): + config = MagicMock() + config.id = "mock_azure_provider" + api_key = "key" + api_endpoint = "http://azureopenai.com" + api_version = "2025-01-01-preview" + return MockAzureProvider( + config=config, + api_endpoint=api_endpoint, + api_key=api_key, + api_version=api_version, + ) diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py new file mode 100644 index 0000000..28e5f68 --- /dev/null +++ b/libs/core/tests/unit_tests/test_azure.py @@ -0,0 +1,232 @@ +from unittest.mock import MagicMock + + +class TestParseResponse: + def test_tool_response_handling(self, mock_azure_provider): + + mock_azure_provider.is_llama = True + mock_azure_provider.has_tools = True + mock_azure_provider.has_functions = False + + mock_azure_provider.handle_tool_response = MagicMock( + return_value=iter(["chunk1", None, "chunk2", None, "chunk3"]) + ) + + response = iter(["irrelevant"]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == ["chunk1", "chunk2", "chunk3"] + mock_azure_provider.handle_tool_response.assert_called_once_with(response) + + def test_direct_response_handling_with_choices(self, mock_azure_provider): + mock_azure_provider.is_llama = False + + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"choices": ["choice1", "choice2"]} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"choices": ["choice2"]} + response = iter([chunk1, chunk2]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == [ + {"choices": ["choice1", "choice2"]}, + {"choices": ["choice2"]}, + ] + chunk1.model_dump.assert_called_once() + chunk2.model_dump.assert_called_once() + + def test_direct_response_handling_without_choices(self, mock_azure_provider): + mock_azure_provider.is_llama = False + + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"key": "value"} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"another_key": "another_value"} + response = iter([chunk1, chunk2]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == [] + chunk1.model_dump.assert_called_once() + chunk2.model_dump.assert_called_once() + + +class TestFormatMessage: + def test_format_message_tool_calls(self, mock_azure_provider): + message = { + "tool_calls": [ + { + "function": { + "name": "example_tool", + "arguments": '{"arg1": "value1"}', + } + } + ] + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + {"arg1": "value1"} + <|eom_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_tool_call_id(self, mock_azure_provider): + message = {"tool_call_id": "123", "content": "This is the tool response."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>ipython<|end_header_id|> + This is the tool response. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_function_call(self, mock_azure_provider): + message = { + "function_call": { + "name": "example_function", + "arguments": '{"arg1": "value1"}', + } + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + {"arg1": "value1"} + <|eom_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_user_message(self, mock_azure_provider): + message = {"role": "user", "content": "This is a user message."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>user<|end_header_id|> + This is a user message. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_assistant_message(self, mock_azure_provider): + message = {"role": "assistant", "content": "This is an assistant message."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + This is an assistant message. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_function_response(self, mock_azure_provider): + message = {"role": "function", "content": "This is the function response."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>ipython<|end_header_id|> + This is the function response. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_empty_message(self, mock_azure_provider): + message = {"role": "user", "content": None} + result = mock_azure_provider.format_message(message) + expected = "" + assert result == expected + + +class TestGenerateClient: + def test_generate_client_with_tools_and_functions(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) + + request = MagicMock() + request.model = "gpt-4" + request.parameters = { + "tools": ["tool1", "tool2"], + "functions": ["function1", "function2"], + "other_param": "value", + } + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "gpt-4", + "messages": "prepared_messages", + "stream": True, + "tools": ["tool1", "tool2"], + "tool_choice": "auto", + "functions": ["function1", "function2"], + "function_call": "auto", + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) + + def test_generate_client_without_tools_or_functions(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) + + request = MagicMock() + request.model = "gpt-4" + request.parameters = {"other_param": "value"} + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "gpt-4", + "messages": "prepared_messages", + "stream": True, + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) + + def test_generate_client_with_llama_model(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) + + request = MagicMock() + request.model = "llama-2" + request.parameters = { + "tools": ["tool1"], + "functions": ["function1"], + "other_param": "value", + } + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "llama-2", + "messages": "prepared_messages", + "stream": True, + "tools": ["tool1"], + "functions": ["function1"], + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) diff --git a/libs/core/tests/unit_tests/test_azure_build.py b/libs/core/tests/unit_tests/test_azure_build.py new file mode 100644 index 0000000..17d73f9 --- /dev/null +++ b/libs/core/tests/unit_tests/test_azure_build.py @@ -0,0 +1,279 @@ +from unittest.mock import MagicMock, patch + + +class TestBuildLlamaSystemMessage: + def test_build_llama_system_message_with_existing_sm(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock( + return_value="Tool Instructions" + ) + mock_azure_provider.build_function_instructions = MagicMock( + return_value="\nFunction Instructions" + ) + + openai_message = [ + {"role": "user", "content": "Hello"}, + {"role": "system", "content": "Custom system message"}, + ] + llama_message = "Initial message" + tools = ["Tool1", "Tool2"] + functions = ["Function1"] + + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " Custom system message\n" + " Tool Instructions\nFunction Instructions\n<|eot_id|>" # identation here exists because in Python when adding a newline to a triple quote string it keeps identation + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) + mock_azure_provider.build_function_instructions.assert_called_once_with( + functions + ) + + def test_build_llama_system_message_with_default_sm(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock( + return_value="Tool Instructions" + ) + mock_azure_provider.build_function_instructions = MagicMock( + return_value="\nFunction Instructions" + ) + + openai_message = [{"role": "user", "content": "Hello"}] + llama_message = "Initial message" + tools = ["Tool1"] + functions = [] + + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " You are a helpful AI assistant.\n" + " Tool Instructions\n<|eot_id|>" + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) + mock_azure_provider.build_function_instructions.assert_not_called() + + def test_build_llama_system_message_without_tools_or_functions( + self, mock_azure_provider + ): + mock_azure_provider.build_tool_instructions = MagicMock() + mock_azure_provider.build_function_instructions = MagicMock() + + openai_message = [{"role": "system", "content": "Minimal system message"}] + llama_message = "Initial message" + tools = [] + functions = [] + + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " Minimal system message\n \n<|eot_id|>" + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_not_called() + mock_azure_provider.build_function_instructions.assert_not_called() + + +class TestBuildInstructions: + def test_build_tool_instructions(self, mock_azure_provider): + tools = [ + { + "type": "function", + "function": { + "name": "python_repl_ast", + "description": "execute Python code", + "parameters": {"query": "string"}, + }, + }, + { + "type": "function", + "function": { + "name": "data_lookup", + "description": "retrieve data from a database", + "parameters": {"database": "string", "query": "string"}, + }, + }, + ] + + result = mock_azure_provider.build_tool_instructions(tools) + + expected = """ + You have access to the following tools: + Use the function 'python_repl_ast' to 'execute Python code': +Parameters format: +{ + "query": "string" +} + +Use the function 'data_lookup' to 'retrieve data from a database': +Parameters format: +{ + "database": "string", + "query": "string" +} + + +If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: +§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} +IMPORTANT: IT IS VITAL THAT YOU NEVER ADD A PREFIX OR A SUFFIX TO THE FUNCTION CALL. + +Here is an example of the output I desiere when performing function call: +§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} +NOTE: There is no prefix before the symbol '§' and nothing comes after the call is done. + + Reminder: + - Function calls MUST follow the specified format. + - Only call one function at a time. + - Required parameters MUST be specified. + - Put the entire function call reply on one line. + - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. + - If you have already called a tool and got the response for the users question please reply with the response. + """ + assert result.strip() == expected.strip() + + def test_build_function_instructions(self, mock_azure_provider): + functions = [ + { + "name": "python_repl_ast", + "description": "execute Python code", + "parameters": {"query": "string"}, + }, + { + "name": "data_lookup", + "description": "retrieve data from a database", + "parameters": {"database": "string", "query": "string"}, + }, + ] + + result = mock_azure_provider.build_function_instructions(functions) + + expected = """ +You have access to the following functions: +Use the function 'python_repl_ast' to: 'execute Python code' +{ + "query": "string" +} + +Use the function 'data_lookup' to: 'retrieve data from a database' +{ + "database": "string", + "query": "string" +} + + +If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: +§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} + +Here is an example of the output I desiere when performing function call: +§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} + +Reminder: +- Function calls MUST follow the specified format. +- Only call one function at a time. +- NEVER call more than one function at a time. +- Required parameters MUST be specified. +- Put the entire function call reply on one line. +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. +- If you have already called a function and got the response for the user's question, please reply with the response. +""" + + assert result.strip() == expected.strip() + + +class TestBuildLlamaConversation: + def test_build_llama_conversation_with_nested_messages(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + { + "role": "user", + "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]", + }, + {"role": "assistant", "content": "assistant reply"}, + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:nested message 1][formatted:nested message 2][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"content": "nested message 1"} + ) + mock_azure_provider.format_message.assert_any_call( + {"content": "nested message 2"} + ) + mock_azure_provider.format_message.assert_any_call( + {"role": "assistant", "content": "assistant reply"} + ) + + def test_build_llama_conversation_with_invalid_nested_content( + self, mock_azure_provider + ): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + {"role": "user", "content": "[invalid json/dict]"}, + {"role": "assistant", "content": "assistant reply"}, + ] + llama_message = "Initial message: " + + with patch("ast.literal_eval", side_effect=ValueError) as mock_literal_eval: + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:[invalid json/dict]][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"role": "user", "content": "[invalid json/dict]"} + ) + mock_azure_provider.format_message.assert_any_call( + {"role": "assistant", "content": "assistant reply"} + ) + + mock_literal_eval.assert_called_once_with("[invalid json/dict]") + + def test_build_llama_conversation_skipping_system_messages( + self, mock_azure_provider + ): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:user message]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"role": "user", "content": "user message"} + ) diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 118367a..42396c9 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -1,30 +1,39 @@ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest -from llmstudio_core.providers.provider import ChatRequest, ProviderError +from llmstudio_core.providers.provider import ChatRequest, ProviderError, time -request = ChatRequest(chat_input="Hello", model="test_model") +request = ChatRequest(chat_input="Hello World", model="test_model") -def test_chat(mock_provider): - mock_provider.generate_client = MagicMock(return_value=MagicMock()) - mock_provider.handle_response = MagicMock(return_value=iter(["response"])) +def test_chat_response_non_stream(mock_provider): + mock_provider.validate_request = MagicMock() + mock_provider.validate_model = MagicMock() + mock_provider.generate_client = MagicMock(return_value="mock_response") + mock_provider.handle_response = MagicMock(return_value="final_response") - print(request.model_dump()) - response = mock_provider.chat(request.chat_input, request.model) + response = mock_provider.chat(chat_input="Hello", model="test_model") - assert response is not None + assert response == "final_response" + mock_provider.validate_request.assert_called_once() + mock_provider.validate_model.assert_called_once() -@pytest.mark.asyncio -async def test_achat(mock_provider): - mock_provider.agenerate_client = AsyncMock(return_value=AsyncMock()) - mock_provider.ahandle_response = AsyncMock(return_value=AsyncMock()) - - print(request.model_dump()) - response = await mock_provider.achat(request.chat_input, request.model) +def test_chat_streaming_response(mock_provider): + mock_provider.validate_request = MagicMock() + mock_provider.validate_model = MagicMock() + mock_provider.generate_client = MagicMock(return_value="mock_response_stream") + mock_provider.handle_response = MagicMock( + return_value=iter(["streamed_response_1", "streamed_response_2"]) + ) - assert response is not None + response_stream = mock_provider.chat( + chat_input="Hello", model="test_model", is_stream=True + ) + assert next(response_stream) == "streamed_response_1" + assert next(response_stream) == "streamed_response_2" + mock_provider.validate_request.assert_called_once() + mock_provider.validate_model.assert_called_once() def test_validate_model(mock_provider): @@ -36,22 +45,214 @@ def test_validate_model(mock_provider): mock_provider.validate_model(request_invalid) -def test_calculate_metrics(mock_provider): - metrics = mock_provider.calculate_metrics( - input="Hello", - output="World", - model="test_model", - start_time=0, - end_time=1, - first_token_time=0.5, - token_times=(0.1, 0.2), - token_count=2, - ) +def test_join_chunks_finish_reason_stop(mock_provider): + current_time = int(time.time()) + chunks = [ + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "Hello, "}, + "finish_reason": None, + "index": 0, + } + ], + }, + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "world!"}, + "finish_reason": "stop", + "index": 0, + } + ], + }, + ] + response, output_string = mock_provider.join_chunks(chunks) + + assert output_string == "Hello, world!" + assert response.choices[0].message.content == "Hello, world!" + + +def test_join_chunks_finish_reason_function_call(mock_provider): + current_time = int(time.time()) + chunks = [ + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": { + "function_call": {"name": "my_function", "arguments": "arg1"} + }, + "finish_reason": None, + "index": 0, + } + ], + }, + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"function_call": {"arguments": "arg2"}}, + "finish_reason": "function_call", + "index": 0, + } + ], + }, + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"function_call": {"arguments": "}"}}, + "finish_reason": "function_call", + "index": 0, + } + ], + }, + ] + response, output_string = mock_provider.join_chunks(chunks) + + assert output_string == "arg1arg2" + assert response.choices[0].message.function_call.arguments == "arg1arg2" + assert response.choices[0].message.function_call.name == "my_function" + + +def test_join_chunks_tool_calls(mock_provider): + current_time = int(time.time()) + + chunks = [ + { + "id": "test_id_1", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": { + "tool_calls": [ + { + "id": "tool_1", + "index": 0, + "function": { + "name": "search_tool", + "arguments": '{"query": "weather', + }, + "type": "function", + } + ] + }, + "finish_reason": None, + "index": 0, + } + ], + }, + { + "id": "test_id_2", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": { + "tool_calls": [ + { + "id": "tool_1", + "index": 0, + "function": { + "name": "search_tool", + "arguments": ' details"}', + }, + } + ] + }, + "finish_reason": "tool_calls", + "index": 0, + } + ], + }, + ] + + response, output_string = mock_provider.join_chunks(chunks) + + assert output_string == "['search_tool', '{\"query\": \"weather details\"}']" + + assert response.object == "chat.completion" + assert response.choices[0].finish_reason == "tool_calls" + tool_call = response.choices[0].message.tool_calls[0] + + assert tool_call.id == "tool_1" + assert tool_call.function.name == "search_tool" + assert tool_call.function.arguments == '{"query": "weather details"}' + assert tool_call.type == "function" + + +def test_input_to_string_with_string(mock_provider): + input_data = "Hello, world!" + assert mock_provider.input_to_string(input_data) == "Hello, world!" + + +def test_input_to_string_with_list_of_text_messages(mock_provider): + input_data = [ + {"content": "Hello"}, + {"content": " world!"}, + ] + assert mock_provider.input_to_string(input_data) == "Hello world!" + + +def test_input_to_string_with_list_of_text_and_url(mock_provider): + input_data = [ + {"role": "user", "content": [{"type": "text", "text": "Hello "}]}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "http://example.com/image.jpg"}, + } + ], + }, + {"role": "user", "content": [{"type": "text", "text": " world!"}]}, + ] + expected_output = "Hello http://example.com/image.jpg world!" + assert mock_provider.input_to_string(input_data) == expected_output + + +def test_input_to_string_with_mixed_roles_and_missing_content(mock_provider): + input_data = [ + {"role": "assistant", "content": "Admin text;"}, + {"role": "user", "content": [{"type": "text", "text": "User text"}]}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "http://example.com/another.jpg"}, + } + ], + }, + ] + expected_output = "Admin text;User texthttp://example.com/another.jpg" + assert mock_provider.input_to_string(input_data) == expected_output + + +def test_input_to_string_with_missing_content_key(mock_provider): + input_data = [ + {"role": "user"}, + {"role": "user", "content": [{"type": "text", "text": "Hello again"}]}, + ] + expected_output = "Hello again" + assert mock_provider.input_to_string(input_data) == expected_output + - assert metrics["input_tokens"] == pytest.approx(1) - assert metrics["output_tokens"] == pytest.approx(1) - assert metrics["cost_usd"] == pytest.approx(0.03) - assert metrics["latency_s"] == pytest.approx(1) - assert metrics["time_to_first_token_s"] == pytest.approx(0.5) - assert metrics["inter_token_latency_s"] == pytest.approx(0.15) - assert metrics["tokens_per_second"] == pytest.approx(2) +def test_input_to_string_with_empty_list(mock_provider): + input_data = [] + assert mock_provider.input_to_string(input_data) == "" diff --git a/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py b/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py new file mode 100644 index 0000000..fb54d60 --- /dev/null +++ b/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py @@ -0,0 +1,128 @@ +from unittest.mock import MagicMock + + +def test_calculate_metrics(mock_provider): + + metrics = mock_provider.calculate_metrics( + input="Hello", + output="Hello World", + model="test_model", + start_time=0.0, + end_time=1.0, + first_token_time=0.5, + token_times=(0.1,), + token_count=2, + ) + + assert metrics["input_tokens"] == 1 + assert metrics["output_tokens"] == 2 + assert metrics["total_tokens"] == 3 + assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 2 # input_cost + output_cost + assert metrics["latency_s"] == 1.0 # end_time - start_time + assert ( + metrics["time_to_first_token_s"] == 0.5 - 0.0 + ) # first_token_time - start_time + assert metrics["inter_token_latency_s"] == 0.1 # Average of token_times + assert metrics["tokens_per_second"] == 2 / 1.0 # token_count / total_time + + +def test_calculate_metrics_single_token(mock_provider): + + metrics = mock_provider.calculate_metrics( + input="Hello", + output="World", + model="test_model", + start_time=0.0, + end_time=1.0, + first_token_time=0.5, + token_times=(), + token_count=1, + ) + + assert metrics["input_tokens"] == 1 + assert metrics["output_tokens"] == 1 + assert metrics["total_tokens"] == 2 + assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 1 + assert metrics["latency_s"] == 1.0 + assert metrics["time_to_first_token_s"] == 0.5 - 0.0 + assert metrics["inter_token_latency_s"] == 0 + assert metrics["tokens_per_second"] == 1 / 1.0 + + +def test_calculate_cost_fixed_cost(mock_provider): + fixed_cost = 0.02 + token_count = 100 + expected_cost = token_count * fixed_cost + assert mock_provider.calculate_cost(token_count, fixed_cost) == expected_cost + + +def test_calculate_cost_variable_cost(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + variable_cost = [cost_range_1, cost_range_2] + token_count = 75 + expected_cost = token_count * 0.02 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + + +def test_calculate_cost_variable_cost_higher_range(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, None) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 150 + expected_cost = token_count * 0.03 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + + +def test_calculate_cost_variable_cost_no_matching_range(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, 150) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 200 + expected_cost = 0 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + + +def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (10, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, 150) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 5 + expected_cost = 0 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost diff --git a/libs/core/tests/unit_tests/test_provider_handle_response.py b/libs/core/tests/unit_tests/test_provider_handle_response.py new file mode 100644 index 0000000..ac04e32 --- /dev/null +++ b/libs/core/tests/unit_tests/test_provider_handle_response.py @@ -0,0 +1,281 @@ +from unittest.mock import MagicMock + +import pytest +from llmstudio_core.providers.provider import ChatCompletion, ChatCompletionChunk, time + + +@pytest.mark.asyncio +async def test_ahandle_response_non_streaming(mock_provider): + request = MagicMock( + is_stream=False, chat_input="Hello", model="test_model", parameters={} + ) + response_chunk = { + "choices": [ + { + "delta": {"content": "Non-streamed response"}, + "finish_reason": "stop", + "index": 0, + } + ], + "model": "test_model", + } + start_time = time.time() + + async def mock_aparse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_aparse_response + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Non-streamed response", + ) + ) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + async for chunk in mock_provider.ahandle_response( + request, mock_aparse_response(), start_time + ): + response.append(chunk) + + assert isinstance(response[0], ChatCompletion) + assert response[0].choices == [] + assert response[0].chat_output == "Non-streamed response" + + +@pytest.mark.asyncio +async def test_ahandle_response_streaming_length(mock_provider): + request = MagicMock( + is_stream=True, chat_input="Hello", model="test_model", parameters={} + ) + response_chunk = { + "choices": [ + { + "delta": {"content": "Streamed response"}, + "finish_reason": "length", + "index": 0, + } + ], + "model": "test_model", + "object": "chat.completion.chunk", + "created": 0, + } + start_time = time.time() + + async def mock_aparse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_aparse_response + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Streamed response", + ) + ) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + async for chunk in mock_provider.ahandle_response( + request, mock_aparse_response(), start_time + ): + response.append(chunk) + + assert isinstance(response[0], ChatCompletionChunk) + assert response[0].chat_output_stream == "Streamed response" + + +@pytest.mark.asyncio +async def test_ahandle_response_streaming_stop(mock_provider): + request = MagicMock( + is_stream=True, chat_input="Hello", model="test_model", parameters={} + ) + response_chunk = { + "choices": [ + { + "delta": {"content": "Streamed response"}, + "finish_reason": "stop", + "index": 0, + } + ], + "model": "test_model", + "object": "chat.completion.chunk", + "created": 0, + } + start_time = time.time() + + async def mock_aparse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_aparse_response + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Streamed response", + ) + ) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + async for chunk in mock_provider.ahandle_response( + request, mock_aparse_response(), start_time + ): + response.append(chunk) + + assert isinstance(response[0], ChatCompletionChunk) + assert response[0].chat_output == "Streamed response" + + +def test_handle_response_non_streaming(mock_provider): + request = MagicMock( + is_stream=False, chat_input="Hello", model="test_model", parameters={} + ) + response_chunk = { + "choices": [ + { + "delta": {"content": "Non-streamed response"}, + "finish_reason": "stop", + "index": 0, + } + ], + "model": "test_model", + } + start_time = time.time() + + def mock_parse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_parse_response + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Non-streamed response", + ) + ) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + for chunk in mock_provider.handle_response( + request, mock_parse_response(), start_time + ): + response.append(chunk) + + assert isinstance(response[0], ChatCompletion) + assert response[0].choices == [] + assert response[0].chat_output == "Non-streamed response" + + +def test_handle_response_streaming_length(mock_provider): + request = MagicMock( + is_stream=True, chat_input="Hello", model="test_model", parameters={} + ) + response_chunk = { + "choices": [ + { + "delta": {"content": "Streamed response"}, + "finish_reason": "length", + "index": 0, + } + ], + "model": "test_model", + "object": "chat.completion.chunk", + "created": 0, + } + start_time = time.time() + + def mock_parse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_parse_response + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Streamed response", + ) + ) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + for chunk in mock_provider.handle_response( + request, mock_parse_response(), start_time + ): + response.append(chunk) + + assert isinstance(response[0], ChatCompletionChunk) + assert response[0].chat_output_stream == "Streamed response" + + +def test_handle_response_streaming_stop(mock_provider): + request = MagicMock( + is_stream=True, chat_input="Hello", model="test_model", parameters={} + ) + response_chunk = { + "choices": [ + { + "delta": {"content": "Streamed response"}, + "finish_reason": "stop", + "index": 0, + } + ], + "model": "test_model", + "object": "chat.completion.chunk", + "created": 0, + } + start_time = time.time() + + def mock_parse_response(*args, **kwargs): + yield response_chunk + + mock_provider.parse_response = mock_parse_response + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Streamed response", + ) + ) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + for chunk in mock_provider.handle_response( + request, mock_parse_response(), start_time + ): + response.append(chunk) + + assert isinstance(response[0], ChatCompletionChunk) + assert response[0].chat_output == "Streamed response"