diff --git a/allms/domain/response.py b/allms/domain/response.py
index 0dcf0dc..d4c3f09 100644
--- a/allms/domain/response.py
+++ b/allms/domain/response.py
@@ -5,13 +5,18 @@
 from allms.domain.input_data import InputData
 
 
+class ResponseParsingOutput(BaseModel):
+    response: typing.Optional[typing.Any]
+    error_message: typing.Optional[str]
+
+
 class ResponseData(BaseModel):
     response: typing.Optional[typing.Any] = None
     input_data: typing.Optional[InputData] = None
 
     number_of_prompt_tokens: typing.Optional[int] = None
     number_of_generated_tokens: typing.Optional[int] = None
-    error: typing.Optional[typing.Union[str, Exception]] = None
+    error: typing.Optional[str] = None
 
     # Without this, only classes inheriting from the pydantic BaseModel are allowed as field types. Exception isn't
     # such a class and that's why we need it.
diff --git a/allms/models/abstract.py b/allms/models/abstract.py
index 4016baa..882c1ee 100644
--- a/allms/models/abstract.py
+++ b/allms/models/abstract.py
@@ -13,7 +13,6 @@
 from langchain.chat_models.base import BaseChatModel
 from langchain.output_parsers import PydanticOutputParser
 from langchain.prompts import ChatPromptTemplate
-from langchain.schema import OutputParserException
 from langchain_core.language_models.llms import create_base_retry_decorator
 from langchain_core.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
 from langchain_core.prompts.prompt import PromptTemplate
@@ -34,6 +33,7 @@
 from allms.domain.prompt_dto import SummaryOutputClass, KeywordsOutputClass
 from allms.domain.response import ResponseData
 from allms.utils.long_text_processing_utils import get_max_allowed_number_of_tokens
+from allms.utils.response_parsing_utils import ResponseParser
 
 logger = logging.getLogger(__name__)
 
@@ -58,6 +58,8 @@ def __init__(
         self._is_long_text_bypass_enabled: bool = False  # Should be false till we fully implement support for long sequences in our package
         self._aggregation_strategy: AggregationLogicForLongInputData = AggregationLogicForLongInputData.SIMPLE_CONCATENATION
         self._parser: typing.Optional[PydanticOutputParser] = None
+        self._json_pattern = re.compile(r"{.*?}", re.DOTALL)
+        self._is_json_format_injected_into_prompt: bool = True  # Subclasses for self-deployed models (Llama2, Mistral, Gemma) set this to False
 
         if max_output_tokens >= model_total_max_tokens:
             raise ValueError("max_output_tokens has to be lower than model_total_max_tokens")
@@ -103,38 +105,9 @@ def generate(
         )
 
         if output_data_model_class:
-            return self._parse_model_output(model_responses)
+            return ResponseParser(self._parser).parse_model_output(model_responses)
         return model_responses
 
-    def _parse_response(self, model_response_data: ResponseData) -> typing.Tuple[str, typing.Optional[str]]:
-        try:
-            return self._parser.parse(model_response_data.response), None
-        except OutputParserException as output_parser_exception:
-            return None, OutputParserException(
-                f"An OutputParserException has occurred for "
-                f"The response from model: {model_response_data.response}\n"
-                f"The exception message: {output_parser_exception}"
-            )
-
-    def _parse_model_output(self, model_responses_data: typing.List[ResponseData]) -> typing.List[ResponseData]:
-        parsed_responses = []
-        for model_response_data in model_responses_data:
-            if not model_response_data.error:
-                response, error_message = self._parse_response(model_response_data)
-
-                parsed_responses.append(ResponseData(
-                    input_data=model_response_data.input_data,
-                    response=response,
-                    error=error_message,
-                    number_of_prompt_tokens=model_response_data.number_of_prompt_tokens,
-                    number_of_generated_tokens=model_response_data.number_of_generated_tokens
-
-                ))
-            else:
-                parsed_responses.append(model_response_data)
-
-        return parsed_responses
-
     async def _generate(
             self,
             prompt: str,
@@ -155,10 +128,12 @@ async def _generate(
         if output_data_model_class:
             self._parser = PydanticOutputParser(pydantic_object=output_data_model_class)
-            prompt_template_args[PromptConstants.PARTIAL_VARIABLES_STR] = {
-                PromptConstants.OUTPUT_DATA_MODEL: self._parser.get_format_instructions(),
-            }
-            prompt_template_args[PromptConstants.TEMPLATE_STR] = self._add_output_data_format(prompt=prompt)
+
+            if self._is_json_format_injected_into_prompt:
+                prompt_template_args[PromptConstants.PARTIAL_VARIABLES_STR] = {
+                    PromptConstants.OUTPUT_DATA_MODEL: self._parser.get_format_instructions(),
+                }
+                prompt_template_args[PromptConstants.TEMPLATE_STR] = self._add_output_data_format(prompt=prompt)
 
         chat_prompts = await self._build_chat_prompts(prompt_template_args, system_prompt)
diff --git a/allms/models/azure_llama2.py b/allms/models/azure_llama2.py
index dc9d514..67e5250 100644
--- a/allms/models/azure_llama2.py
+++ b/allms/models/azure_llama2.py
@@ -1,11 +1,15 @@
 import typing
 from asyncio import AbstractEventLoop
+from typing import List, Type
 
 from langchain_community.chat_models.azureml_endpoint import LlamaChatContentFormatter
+from pydantic import BaseModel
 
 from allms.defaults.azure_defaults import AzureLlama2Defaults
 from allms.defaults.general_defaults import GeneralDefaults
 from allms.domain.configuration import AzureSelfDeployedConfiguration
+from allms.domain.input_data import InputData
+from allms.domain.response import ResponseData
 from allms.models.abstract import AbstractModel
 from allms.models.azure_base import AzureMLOnlineEndpointAsync
@@ -35,6 +39,8 @@ def __init__(
             event_loop=event_loop
         )
 
+        self._is_json_format_injected_into_prompt = False
+
     def _create_llm(self) -> AzureMLOnlineEndpointAsync:
         model_kwargs = {"max_new_tokens": self._max_output_tokens, "top_p": self._top_p, "do_sample": False}
         if self._temperature > 0:
diff --git a/allms/models/azure_mistral.py b/allms/models/azure_mistral.py
index 0b6fe1b..d9e7dd6 100644
--- a/allms/models/azure_mistral.py
+++ b/allms/models/azure_mistral.py
@@ -35,6 +35,8 @@ def __init__(
             event_loop=event_loop
         )
 
+        self._is_json_format_injected_into_prompt = False
+
     def _create_llm(self) -> AzureMLOnlineEndpointAsync:
         model_kwargs = {
             "max_new_tokens": self._max_output_tokens, "top_p": self._top_p, "do_sample": False,
diff --git a/allms/models/vertexai_gemma.py b/allms/models/vertexai_gemma.py
index c4cb491..eb725c4 100644
--- a/allms/models/vertexai_gemma.py
+++ b/allms/models/vertexai_gemma.py
@@ -38,6 +38,8 @@ def __init__(
             event_loop=event_loop
         )
 
+        self._is_json_format_injected_into_prompt = False
+
     def _create_llm(self) -> VertexAIModelGarden:
         return VertexAIModelGardenWrapper(
             model_name=GemmaModelDefaults.GCP_MODEL_NAME,
diff --git a/allms/utils/response_parsing_utils.py b/allms/utils/response_parsing_utils.py
new file mode 100644
index 0000000..4fc8fff
--- /dev/null
+++ b/allms/utils/response_parsing_utils.py
@@ -0,0 +1,70 @@
+import re
+import typing
+
+from langchain.output_parsers import PydanticOutputParser
+from langchain.schema import OutputParserException
+
+from allms.domain.response import ResponseData, ResponseParsingOutput
+
+
+class ResponseParser:
+    def __init__(self, parser: PydanticOutputParser) -> None:
+        self._json_pattern = re.compile(r"{.*?}", re.DOTALL)
+        self._parser = parser
+
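+    # Note on the helpers below: the regex pulls the first {...} match out of the raw model
+    # response (self-deployed models often wrap the JSON in extra text, e.g. "Sure! Here's the
+    # JSON you wanted: {...}"), and the cleaning step removes literal \n sequences and stray
+    # backslashes before the result is handed to the PydanticOutputParser.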
+    def _clean_extracted_json(self, extracted_json: str) -> str:
+        json_without_newlines = extracted_json.replace("\\n", "")
+        json_without_backslashes = json_without_newlines.replace("\\", "")
+
+        return json_without_backslashes
+
+    def _extract_json_from_response(self, model_response_data: ResponseData) -> str:
+        search_results = self._json_pattern.findall(model_response_data.response)
+
+        if len(search_results) == 0:
+            return model_response_data.response
+
+        return self._clean_extracted_json(search_results[0])
+
+    def _parse_response(
+            self,
+            model_response_data: ResponseData
+    ) -> ResponseParsingOutput:
+        raw_response = self._extract_json_from_response(model_response_data)
+
+        try:
+            return ResponseParsingOutput(
+                response=self._parser.parse(raw_response),
+                error_message=None
+            )
+        except OutputParserException as output_parser_exception:
+            return ResponseParsingOutput(
+                response=None,
+                error_message=f"""
+                    An OutputParserException has occurred for the model response: {raw_response}
+                    The exception message: {output_parser_exception}
+                """
+            )
+
+    def parse_model_output(
+            self,
+            model_responses_data: typing.List[ResponseData]
+    ) -> typing.List[ResponseData]:
+        parsed_responses = []
+
+        for model_response_data in model_responses_data:
+            if not model_response_data.error:
+                response_with_error = self._parse_response(model_response_data)
+
+                parsed_responses.append(ResponseData(
+                    input_data=model_response_data.input_data,
+                    response=response_with_error.response,
+                    error=response_with_error.error_message,
+                    number_of_prompt_tokens=model_response_data.number_of_prompt_tokens,
+                    number_of_generated_tokens=model_response_data.number_of_generated_tokens
+
+                ))
+            else:
+                parsed_responses.append(model_response_data)
+
+        return parsed_responses
\ No newline at end of file
diff --git a/docs/api/models/azure_llama2_model.md b/docs/api/models/azure_llama2_model.md
index 469c248..54d5d41 100644
--- a/docs/api/models/azure_llama2_model.md
+++ b/docs/api/models/azure_llama2_model.md
@@ -39,8 +39,7 @@ generate(
 - `input_data` (`Optional[List[InputData]]`): If prompt contains symbolic variables you can use this parameter
   to generate model responses for batch of examples. Each symbolic variable from the prompt should have mapping provided
   in the `input_mappings` of `InputData`.
-- `output_data_model_class` (`Optional[Type[BaseModel]]`): If provided forces the model to generate output in the
-  format defined by the passed class. Generated response is automatically parsed to this class.
+- `output_data_model_class` (`Optional[Type[BaseModel]]`): The generated response is automatically parsed to this class. WARNING: You need to manually provide the JSON format instructions in the prompt; they are not injected for this model.
 
 #### Returns
 `List[ResponseData]`: Each `ResponseData` contains the response for a single example from `input_data`. If `input_data`
diff --git a/docs/api/models/azure_mistral_model.md b/docs/api/models/azure_mistral_model.md
index 19d90f1..b4adf86 100644
--- a/docs/api/models/azure_mistral_model.md
+++ b/docs/api/models/azure_mistral_model.md
@@ -37,8 +37,7 @@ generate(
 - `input_data` (`Optional[List[InputData]]`): If prompt contains symbolic variables you can use this parameter
   to generate model responses for batch of examples. Each symbolic variable from the prompt should have mapping provided
   in the `input_mappings` of `InputData`.
-- `output_data_model_class` (`Optional[Type[BaseModel]]`): If provided forces the model to generate output in the
-  format defined by the passed class. Generated response is automatically parsed to this class.
+- `output_data_model_class` (`Optional[Type[BaseModel]]`): The generated response is automatically parsed to this class. WARNING: You need to manually provide the JSON format instructions in the prompt; they are not injected for this model.
 
 Note that Mistral-based models currently don't support system prompts.
diff --git a/docs/api/models/vertexai_gemma.md b/docs/api/models/vertexai_gemma.md
index e20bcdb..bf298e4 100644
--- a/docs/api/models/vertexai_gemma.md
+++ b/docs/api/models/vertexai_gemma.md
@@ -44,8 +44,7 @@ generate(
 - `input_data` (`Optional[List[InputData]]`): If prompt contains symbolic variables you can use this parameter
   to generate model responses for batch of examples. Each symbolic variable from the prompt should have mapping provided
   in the `input_mappings` of `InputData`.
-- `output_data_model_class` (`Optional[Type[BaseModel]]`): If provided forces the model to generate output in the
-  format defined by the passed class. Generated response is automatically parsed to this class.
+- `output_data_model_class` (`Optional[Type[BaseModel]]`): The generated response is automatically parsed to this class. WARNING: You need to manually provide the JSON format instructions in the prompt; they are not injected for this model.
 
 #### Returns
 `List[ResponseData]`: Each `ResponseData` contains the response for a single example from `input_data`. If `input_data`
diff --git a/docs/usage/forcing_response_format.md b/docs/usage/forcing_response_format.md
index b348c05..72f1f3e 100644
--- a/docs/usage/forcing_response_format.md
+++ b/docs/usage/forcing_response_format.md
@@ -66,13 +66,10 @@ False
 
 ## What to do when output formatting doesn't work?
 
-The feature described above works best with advanced proprietary models like GPT and PaLM/Gemini. Less capable models like Llama2 or Mistral
-may not able to understand instructions passed as output_dataclasses, and in most cases the returned response won't be compatible
-with the defined format, resulting in an unexpected response.
+The feature described above works only with advanced proprietary models like GPT and PaLM/Gemini. Less capable models like Llama2 or Mistral
+are unable to understand instructions passed as output_dataclasses.
 
-In such cases, we recommend to address the issue by specifying in the prompt how the response should look like. Using
-few-shot learning techniques is also advisable. In the case of JSON-like output, use double curly brackets to escape them in order
-to use them in the JSON example.
+For these less capable models, you need to specify in the prompt yourself what the response should look like. You can then pass `output_data_model_class`, and `allms` will try to parse the output into it. Using few-shot learning techniques is also advisable. In the case of JSON-like output, use double curly brackets instead of single ones, e.g. `{{"key": "value"}}` instead of `{"key": "value"}`.
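+
+For example, here is a minimal sketch of this approach (the prompt, the `ReviewSummary` output class and the input data are illustrative; `model` is assumed to be an already-configured self-deployed model such as `AzureLlama2Model`):
+
+```python
+from pydantic import BaseModel
+
+from allms.domain.input_data import InputData
+
+
+class ReviewSummary(BaseModel):
+    summary: str
+
+
+prompt = """Summarize the review below in one sentence.
+
+Review: {review}
+
+Return only a JSON object in exactly this format: {{"summary": "<one-sentence summary>"}}"""
+
+input_data = [InputData(input_mappings={"review": "Great battery life, mediocre camera."}, id="1")]
+
+responses = model.generate(
+    prompt=prompt,
+    input_data=input_data,
+    output_data_model_class=ReviewSummary
+)
+summary = responses[0].response  # a ReviewSummary instance if parsing succeeded, otherwise None with details in responses[0].error
+```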
 
 ## How forcing response format works under the hood?
 To force the model to provide output in a desired format, under the hood `allms` automatically adds a description
@@ -90,7 +87,7 @@ Here is the output schema:
 ```
 ````
 
-This feature is really helpful, but you have to bear in mind that by using it you increase the number or prompt tokens
+This feature is really helpful, but you have to keep in mind that by using it you increase the number of prompt tokens
 so it'll make the requests more costly (if you're using model with per token pricing)
 
 If the model will return an output that doesn't comform to the defined data model, raw model response will be returned
diff --git a/poetry.lock b/poetry.lock
index bfb9a58..5dd80be 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.5.0 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
 
 [[package]]
 name = "aiohttp"
@@ -2167,6 +2167,23 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
 [package.extras]
 testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
 
+[[package]]
+name = "pytest-mock"
+version = "3.14.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+    {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
+]
+
+[package.dependencies]
+pytest = ">=6.2.5"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
 [[package]]
 name = "python-dateutil"
 version = "2.8.2"
@@ -3405,4 +3422,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "2a63aeb94ee8c2072bdc003d9acc29b224e28d3589908645a0d8e88ca703b8ba"
+content-hash = "7915b11fb574bdc236e238e884d7cbc0e43849b0750577f6d1aedd69b00162f6"
diff --git a/pyproject.toml b/pyproject.toml
index 7cef706..7804bd0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "allms"
-version = "1.0.1"
+version = "1.0.2"
 description = ""
 authors = ["Allegro Opensource "]
 readme = "README.md"
@@ -17,6 +17,7 @@ langchain = "^0.0.351"
 aioresponses = "^0.7.6"
 tiktoken = "^0.6.0"
 openai = "^0.27.8"
+pytest-mock = "^3.14.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.0"
diff --git a/tests/conftest.py b/tests/conftest.py
index 0843bb7..e28fea9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -27,7 +27,7 @@ class GenerativeModels:
     vertex_palm: typing.Optional[VertexAIPalmModel] = None
 
 
-class VertexAIMock(FakeListLLM):
+class ModelWithoutAsyncRequestsMock(FakeListLLM):
     def __init__(self, *args, **kwargs):
         super().__init__(responses=["{}"])
 
@@ -37,9 +37,11 @@ def models():
     event_loop = asyncio.new_event_loop()
 
     with (
-        patch("allms.models.vertexai_palm.CustomVertexAI", VertexAIMock),
-        patch("allms.models.vertexai_gemini.CustomVertexAI", VertexAIMock),
-        patch("allms.models.vertexai_gemma.VertexAIModelGardenWrapper", VertexAIMock)
+        patch("allms.models.vertexai_palm.CustomVertexAI", ModelWithoutAsyncRequestsMock),
+        patch("allms.models.vertexai_gemini.CustomVertexAI", ModelWithoutAsyncRequestsMock),
+        patch("allms.models.vertexai_gemma.VertexAIModelGardenWrapper", ModelWithoutAsyncRequestsMock),
+        patch("allms.models.azure_llama2.AzureMLOnlineEndpointAsync", ModelWithoutAsyncRequestsMock),
+        patch("allms.models.azure_mistral.AzureMLOnlineEndpointAsync", ModelWithoutAsyncRequestsMock)
     ):
         return {
             "azure_open_ai": AzureOpenAIModel(
diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py
index 1950282..24cf19e 100644
--- a/tests/test_end_to_end.py
+++ b/tests/test_end_to_end.py
@@ -1,5 +1,8 @@
 import re
+from unittest.mock import patch
+
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate, SystemMessagePromptTemplate
+
 from allms.constants.input_data import IODataConstants
 from allms.domain.prompt_dto import KeywordsOutputClass
 from allms.utils import io_utils
@@ -66,6 +69,82 @@ def test_model_is_queried_successfully(
         assert list(map(lambda output: int(output[IODataConstants.GENERATED_TOKENS_NUMBER]), expected_output)) == list(
             map(lambda example: example.number_of_generated_tokens, parsed_responses))
+
+
+    def test_prompt_is_not_modified_for_open_source_models(self, mock_aioresponse, models, mocker):
+        # GIVEN
+        open_source_models = ["azure_llama2", "azure_mistral", "vertex_gemma"]
+
+        mock_aioresponse.post(
+            url=re.compile(f"^https:\/\/dummy-endpoint.*$"),
+            payload={
+                "choices": [{
+                    "message": {
+                        "content": "{\"keywords\": [\"Indywidualna racja żywnościowa\", \"wojskowa\", \"S-R-9\", \"set nr 9\", \"Makaron po bolońsku\", \"Konserwa tyrolska\", \"Suchary\", \"Koncentrat napoju herbacianego instant o smaku owoców leśnych\", \"Dżem malinowy\", \"Baton zbożowo-owocowy o smaku figowym\"]}",
+                        "role": ""
+                    }
+                }],
+                "usage": {}
+            },
+            repeat=True
+        )
+
+        input_data = io_utils.load_csv_to_input_data(
+            limit=5,
+            path="./tests/resources/test_input_data.csv"
+        )
+        prompt_template_text = """Extract at most 10 keywords that could be used as features in a search index from this Polish product description.
+
+        {text}
+        """
+        prompt_template_spy = mocker.spy(ChatPromptTemplate, "from_messages")
+
+        # WHEN & THEN
+        for model_name, model in models.items():
+            model.generate(
+                prompt=prompt_template_text,
+                input_data=input_data,
+                output_data_model_class=KeywordsOutputClass,
+                system_prompt=None if model_name == "azure_mistral" else "This is a system prompt."
+            )
+
+            if model_name in open_source_models:
+                messages = [
+                    HumanMessagePromptTemplate(
+                        prompt=PromptTemplate(
+                            input_variables=["text"],
+                            template=prompt_template_text
+                        )
+                    )
+                ]
+                if model_name != "azure_mistral":
+                    messages = [
+                        SystemMessagePromptTemplate(
+                            prompt=PromptTemplate(
+                                input_variables=[],
+                                template="This is a system prompt."
+                            )
+                        )
+                    ] + messages
+                prompt_template_spy.assert_called_with(messages)
+            else:
+                prompt_template_spy.assert_called_with([
+                    SystemMessagePromptTemplate(
+                        prompt=PromptTemplate(
+                            input_variables=[],
+                            template="This is a system prompt."
+                        )
+                    ),
+                    HumanMessagePromptTemplate(
+                        prompt=PromptTemplate(
+                            input_variables=["text"],
+                            partial_variables={
+                                'output_data_model': 'The output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"keywords": {"title": "Keywords", "description": "List of keywords", "type": "array", "items": {"type": "string"}}}, "required": ["keywords"]}\n```'
+                            },
+                            template=f"{prompt_template_text}\n\n{{output_data_model}}"
+                        )
+                    )
+                ])
 
     def test_model_times_out(
         self,
diff --git a/tests/test_output_parser.py b/tests/test_output_parser.py
index 5e02556..b32179c 100644
--- a/tests/test_output_parser.py
+++ b/tests/test_output_parser.py
@@ -2,6 +2,7 @@
 from unittest.mock import patch
 
 from langchain.schema import OutputParserException
+import pytest
 
 from allms.domain.input_data import InputData
 from allms.domain.prompt_dto import SummaryOutputClass, KeywordsOutputClass
@@ -41,7 +42,44 @@ def test_output_parser_returns_error_when_model_output_returns_different_field(s
         # WHEN & THEN
         for model in models.values():
             model_response = model.generate(prompt, input_data, SummaryOutputClass)
-            assert type(model_response[0].error) == OutputParserException
+            assert "OutputParserException" in model_response[0].error
+            assert model_response[0].response is None
+
+    @patch("langchain.chains.base.Chain.arun")
+    @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens")
+    @pytest.mark.parametrize("json_response", [
+        ("{\"summary\": \"This is the model output\"}"),
+        ("Sure! Here's the JSON you wanted: {\"summary\": \"This is the model output\"} Have a nice day!"),
+        ("<>\\n{\\n \"summary\": \"This is the model output\"\\n}\\n<>"),
+        ("{\\\"summary\\\": \\\"This is the model output\\\"}\\n}")
+    ])
+    def test_output_parser_extracts_json_from_response(self, tokens_mock, chain_run_mock, models, json_response):
+        # GIVEN
+        chain_run_mock.return_value = json_response
+        tokens_mock.return_value = 1
+
+        input_data = [InputData(input_mappings={"text": "Some dummy text"}, id="1")]
+        prompt = "Some Dummy Prompt {text}"
+
+        # WHEN & THEN
+        for model in models.values():
+            model_response = model.generate(prompt, input_data, SummaryOutputClass)
+            assert model_response[0].response == SummaryOutputClass(summary="This is the model output")
+
+    @patch("langchain.chains.base.Chain.arun")
+    @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens")
+    def test_output_parser_returns_error_when_json_is_garbled(self, tokens_mock, chain_run_mock, models):
+        # GIVEN
+        chain_run_mock.return_value = "Sure! Here's the JSON you wanted: {\"summary: \"text\"}"
+        tokens_mock.return_value = 1
+
+        input_data = [InputData(input_mappings={"text": "Some dummy text"}, id="1")]
+        prompt = "Some Dummy Prompt {text}"
+
+        # WHEN & THEN
+        for model in models.values():
+            model_response = model.generate(prompt, input_data, SummaryOutputClass)
+            assert "OutputParserException" in model_response[0].error
             assert model_response[0].response is None
 
     @patch("langchain.chains.base.Chain.arun")
@@ -94,4 +132,4 @@ def test_model_output_when_input_data_is_empty(self, tokens_mock, chain_run_mock
         for model in models.values():
             model_response = model.generate(prompt, None, KeywordsOutputClass)
             assert model_response[0].response is None
-            assert type(model_response[0].error) == OutputParserException
+            assert "OutputParserException" in model_response[0].error
\ No newline at end of file