Skip to content

Commit

Permalink
Feature initial response parsers (#165)
Browse files Browse the repository at this point in the history
* Initial implementation

* Add unit tests for yaml and json

* Update requirements.txt

* Update docstrings

* Add non_empty_validator() as an alternative validation method

* Replace non_empty_validator() with Field

* Add LLMFunctionWithPrompt to __init__
  • Loading branch information
Winston-503 authored Sep 26, 2024
1 parent 3ad8e47 commit d84ffb8
Show file tree
Hide file tree
Showing 8 changed files with 382 additions and 88 deletions.
1 change: 1 addition & 0 deletions council/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ExecuteLLMRequest,
)
from .llm_function import LLMFunction, LLMFunctionError, FunctionOutOfRetryError
from .llm_function_with_prompt import LLMFunctionWithPrompt
from .monitored_llm import MonitoredLLM

from .chat_gpt_configuration import ChatGPTConfigurationBase
Expand Down
54 changes: 2 additions & 52 deletions council/llm/llm_function.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,11 @@
from dataclasses import dataclass, is_dataclass
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, Type, TypeVar, Union
from typing import Any, Generic, Iterable, List, Optional, Sequence, Union

from council.contexts import LLMContext

from ..utils import CodeParser
from .llm_answer import LLMParsingException
from .llm_base import LLMBase, LLMMessage
from .llm_middleware import LLMMiddleware, LLMMiddlewareChain, LLMRequest, LLMResponse

T_Response = TypeVar("T_Response")
LLMResponseParser = Callable[[LLMResponse], T_Response]

T_Dataclass = TypeVar("T_Dataclass")


def code_blocks_response_parser(cls: Type[T_Dataclass]) -> Type[T_Dataclass]:
    """
    Decorator providing an automatic parsing of LLMResponse translating code blocks content into class fields.

    The decorated class is turned into a dataclass (unless it already is one) and receives a
    `from_response` static method. For every dataclass field, `from_response` looks up the first
    code block named after the field and converts its stripped content to the annotated type.
    Supported field types are `str`, `bool`, `int` and `float`.
    Implement validate(self) to provide additional validation functionality.

    `from_response` raises:
        LLMParsingException: a block is missing or its content cannot be converted to the field type.
        ValueError: a field is annotated with an unsupported type.
    """
    if not is_dataclass(cls):
        cls = dataclass(cls)

    def from_response(response: LLMResponse) -> T_Dataclass:
        llm_response = response.value
        parsed_blocks: Dict[str, Any] = {}

        for field_name, field in cls.__dataclass_fields__.items():  # type: ignore
            block = CodeParser.find_first(field_name, llm_response)
            if block is None:
                raise LLMParsingException(f"`{field_name}` block is not found")

            field_type = field.type
            value = block.code.strip()
            if field_type is str:
                parsed_blocks[field_name] = value
            elif field_type is bool:
                lowered = value.lower()
                if lowered not in ("true", "false"):
                    raise LLMParsingException(f"Cannot convert value `{value}` to bool for field `{field_name}`")
                parsed_blocks[field_name] = lowered == "true"
            elif field_type in (int, float):
                try:
                    parsed_blocks[field_name] = field_type(value)
                except ValueError as e:
                    # keep the original conversion error attached as the cause
                    raise LLMParsingException(
                        f"Cannot convert value `{value}` to {field_type.__name__} for field `{field_name}`"
                    ) from e
            else:
                # string annotations and some typing constructs have no __name__;
                # fall back to their str() so the intended ValueError is always raised
                type_name = getattr(field_type, "__name__", str(field_type))
                raise ValueError(f"Unsupported type `{type_name}` for field `{field_name}`")

        instance = cls(**parsed_blocks)  # code blocks in LLM response template must match class fields
        if hasattr(instance, "validate"):
            instance.validate()
        return instance

    setattr(cls, "from_response", staticmethod(from_response))
    return cls
from .llm_response_parser import LLMResponseParser, T_Response


class LLMFunctionError(Exception):
Expand Down
139 changes: 139 additions & 0 deletions council/llm/llm_response_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import json
import re
from typing import Any, Callable, Dict, Type, TypeVar

import yaml
from pydantic import BaseModel, ValidationError

from ..utils import CodeParser
from .llm_answer import LLMParsingException
from .llm_middleware import LLMResponse

T_Response = TypeVar("T_Response")
LLMResponseParser = Callable[[LLMResponse], T_Response]

T = TypeVar("T", bound="BaseModelResponseParser")


class BaseModelResponseParser(BaseModel):
    """Base class for parsing LLM responses into structured data models"""

    @classmethod
    def from_response(cls: Type[T], response: LLMResponse) -> T:
        """
        Parse an LLM response into a structured data model.
        Must be implemented by subclasses to define specific parsing logic.
        """
        raise NotImplementedError()

    def validator(self) -> None:
        """
        Implement custom validation logic for the parsed data.
        Can be overridden by subclasses to add specific validation rules.
        Raise LLMParsingException to trigger local correction.
        Alternatively, use pydantic validation.
        """

    @classmethod
    def create_and_validate(cls: Type[T], **kwargs) -> T:
        """
        Create an instance from keyword arguments and run its custom validation.

        The instance is built with `_try_create()`, then `validator()` is called on it
        before it is returned.
        Raises an LLMParsingException if instantiation fails with a pydantic ValidationError;
        any exception raised by `validator()` propagates unchanged.
        """
        instance = cls._try_create(**kwargs)
        instance.validator()
        return instance

    @classmethod
    def _try_create(cls: Type[T], **kwargs) -> T:
        """
        Attempt to create a BaseModel object instance.
        Raises an LLMParsingException if a ValidationError occurs during instantiation.
        """

        try:
            return cls(**kwargs)
        except ValidationError as e:
            # LLM-friendlier version of pydantic error message without "For further information visit..."
            clean_exception_message = re.sub(r"For further information visit.*", "", str(e))
            # chain the pydantic error so the original cause stays available in tracebacks
            raise LLMParsingException(clean_exception_message) from e


class CodeBlocksResponseParser(BaseModelResponseParser):
    """Parser for responses made of several named code blocks, one block per model field"""

    @classmethod
    def from_response(cls: Type[T], response: LLMResponse) -> T:
        """
        Build the model from the code blocks found in the response.

        For each model field, the first code block named after the field is looked up and
        its stripped content becomes the field's raw value.
        Raises an LLMParsingException as soon as a field has no matching block.
        """
        text = response.value
        raw_values: Dict[str, Any] = {}

        for field_name in cls.model_fields:
            found = CodeParser.find_first(field_name, text)
            if found is None:
                raise LLMParsingException(f"`{field_name}` block is not found")
            raw_values[field_name] = found.code.strip()

        return cls.create_and_validate(**raw_values)


class YAMLBlockResponseParser(BaseModelResponseParser):
    """Parser for responses whose payload is a single `yaml` code block"""

    @classmethod
    def from_response(cls: Type[T], response: LLMResponse) -> T:
        """
        Find the first `yaml` code block in the response, parse it with
        YAMLResponseParser.parse() and build the model from the parsed content.
        Raises an LLMParsingException if there is no `yaml` block.
        """
        block = CodeParser.find_first("yaml", response.value)
        if block is not None:
            parsed = YAMLResponseParser.parse(block.code)
            return cls.create_and_validate(**parsed)

        raise LLMParsingException("yaml block is not found")


class YAMLResponseParser(BaseModelResponseParser):
    """Parser for responses containing raw YAML content"""

    @classmethod
    def from_response(cls: Type[T], response: LLMResponse) -> T:
        """
        Parse the whole response text as YAML and build the model from the resulting mapping.
        Raises an LLMParsingException if the text is not valid YAML or is not a mapping.
        """
        llm_response = response.value

        yaml_content = YAMLResponseParser.parse(llm_response)
        return cls.create_and_validate(**yaml_content)

    @staticmethod
    def parse(content: str) -> Dict[str, Any]:
        """
        Load `content` with yaml.safe_load() and return it as a dictionary with string keys.

        Raises an LLMParsingException if the content is not valid YAML, if the top-level value
        is not a mapping (e.g. empty content, a plain scalar or a list), or if any top-level key
        is not a string. Without these checks, unpacking the result with `**` would fail
        with a TypeError instead of an LLMParsingException.
        """
        try:
            parsed = yaml.safe_load(content)
        except yaml.YAMLError as e:
            raise LLMParsingException(f"Error while parsing yaml: {e}") from e

        if not isinstance(parsed, dict):
            raise LLMParsingException(
                f"Error while parsing yaml: expected a mapping at the top level, got `{type(parsed).__name__}`"
            )

        invalid_keys = [key for key in parsed if not isinstance(key, str)]
        if invalid_keys:
            raise LLMParsingException(f"Error while parsing yaml: all keys must be strings, got {invalid_keys}")

        return parsed


class JSONBlockResponseParser(BaseModelResponseParser):
    """Parser for responses whose payload is a single `json` code block"""

    @classmethod
    def from_response(cls: Type[T], response: LLMResponse) -> T:
        """
        Find the first `json` code block in the response, parse it with
        JSONResponseParser.parse() and build the model from the parsed content.
        Raises an LLMParsingException if there is no `json` block.
        """
        block = CodeParser.find_first("json", response.value)
        if block is not None:
            parsed = JSONResponseParser.parse(block.code)
            return cls.create_and_validate(**parsed)

        raise LLMParsingException("json block is not found")


class JSONResponseParser(BaseModelResponseParser):
    """Parser for responses containing raw JSON content"""

    @classmethod
    def from_response(cls: Type[T], response: LLMResponse) -> T:
        """
        Parse the whole response text as JSON and build the model from the resulting object.
        Raises an LLMParsingException if the text is not valid JSON or is not a JSON object.
        """
        llm_response = response.value

        json_content = JSONResponseParser.parse(llm_response)
        return cls.create_and_validate(**json_content)

    @staticmethod
    def parse(content: str) -> Dict[str, Any]:
        """
        Decode `content` with json.loads() and return it as a dictionary.

        Raises an LLMParsingException if the content is not valid JSON or if the top-level
        value is not a JSON object (e.g. a list, a string, a number or null). Without this check,
        unpacking the result with `**` would fail with a TypeError instead of an LLMParsingException.
        """
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError as e:
            raise LLMParsingException(f"Error while parsing json: {e}") from e

        if not isinstance(parsed, dict):
            raise LLMParsingException(
                f"Error while parsing json: expected an object at the top level, got `{type(parsed).__name__}`"
            )

        return parsed
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ GoogleNews>=1.6.10
pymediawiki~=0.7.3
beautifulsoup4~=4.12.2

# Response Parsers
pydantic==2.8.*
8 changes: 4 additions & 4 deletions tests/integration/llm/test_llm_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from council import AzureLLM
from council.llm import LLMParsingException, LLMMessage
from council.llm.llm_function import LLMFunction, code_blocks_response_parser
from council.llm.llm_function import LLMFunction
from council.llm.llm_response_parser import CodeBlocksResponseParser
from council.prompt import LLMPromptConfigObject
from tests import get_data_filename
from tests.unit import LLMPrompts
Expand All @@ -14,13 +15,12 @@
USER = prompt_config.get_user_prompt_template("default")


@code_blocks_response_parser
class SQLResult:
class SQLResult(CodeBlocksResponseParser):
solved: bool
explanation: str
sql: str

def validate(self) -> None:
def validator(self) -> None:
if "limit" not in self.sql.lower():
raise LLMParsingException("Generated SQL query should contain a LIMIT clause")

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/llm/test_llm_function_with_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dotenv

from council import AzureLLM
from council.llm.llm_function_with_prompt import LLMFunctionWithPrompt
from council.llm import LLMFunctionWithPrompt
from council.prompt import LLMPromptConfigObject
from tests import get_data_filename
from tests.integration.llm.test_llm_function import SQLResult
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
import unittest

from pydantic import field_validator, Field

from council.llm import LLMParsingException
from council.llm.llm_function import (
code_blocks_response_parser,
LLMFunction,
FunctionOutOfRetryError,
)
from council.llm.llm_function import LLMFunction, FunctionOutOfRetryError
from council.llm.llm_response_parser import CodeBlocksResponseParser
from council.mocks import MockLLM, MockMultipleResponses


@code_blocks_response_parser
class Response:
text: str
class Response(CodeBlocksResponseParser):
text: str = Field(..., min_length=1)
flag: bool
age: int
number: float

def validate(self) -> None:
@field_validator("text")
@classmethod
def n(cls, text: str) -> str:
if text == "incorrect":
raise ValueError(f"Incorrect `text` value: `{text}`")
return text

def validator(self) -> None:
if self.age < 0:
raise LLMParsingException(f"Age must be a positive number; got `{self.age}`")


@code_blocks_response_parser
class BadResponse:
complex_type: Response


def format_response(text: str, flag: str, age: str, number: str) -> str:
return f"""
```text
Expand Down Expand Up @@ -65,41 +65,61 @@ def test_wrong_bool(self):
with self.assertRaises(FunctionOutOfRetryError) as e:
_ = execute_mock_llm_func(llm, Response.from_response)

assert str(e.exception).strip().endswith("Cannot convert value `not-a-bool` to bool for field `flag`")
self.assertIn(
"Input should be a valid boolean, unable to interpret input "
"[type=bool_parsing, input_value='not-a-bool', input_type=str]",
str(e.exception),
)

def test_wrong_int(self):
llm = MockLLM.from_response(format_response(text="Some text", flag="true", age="not-an-int", number="3.14"))
with self.assertRaises(FunctionOutOfRetryError) as e:
_ = execute_mock_llm_func(llm, Response.from_response)

assert str(e.exception).strip().endswith("Cannot convert value `not-an-int` to int for field `age`")
self.assertIn(
"Input should be a valid integer, unable to parse string as an integer "
"[type=int_parsing, input_value='not-an-int', input_type=str]",
str(e.exception),
)

def test_validate_int(self):
llm = MockLLM.from_response(format_response(text="Some text", flag="true", age="-5", number="3.14"))
def test_wrong_float(self):
llm = MockLLM.from_response(format_response(text="Some text", flag="true", age="34", number="not-a-float"))
with self.assertRaises(FunctionOutOfRetryError) as e:
_ = execute_mock_llm_func(llm, Response.from_response)

assert str(e.exception).strip().endswith("Age must be a positive number; got `-5`")
self.assertIn(
"Input should be a valid number, unable to parse string as a number "
"[type=float_parsing, input_value='not-a-float', input_type=str]",
str(e.exception),
)

def test_wrong_float(self):
llm = MockLLM.from_response(format_response(text="Some text", flag="true", age="34", number="not-a-float"))
def test_pydentic_validation(self):
llm = MockLLM.from_response(format_response(text="incorrect", flag="true", age="34", number="3.14"))
with self.assertRaises(FunctionOutOfRetryError) as e:
_ = execute_mock_llm_func(llm, Response.from_response)

assert str(e.exception).strip().endswith("Cannot convert value `not-a-float` to float for field `number`")
self.assertIn(
"Value error, Incorrect `text` value: `incorrect` "
"[type=value_error, input_value='incorrect', input_type=str]",
str(e.exception),
)

def test_wrong_type(self):
llm = MockLLM.from_response(
"""
```complex_type
Some text
```
"""
def test_non_empty_validator(self):
llm = MockLLM.from_response(format_response(text=" ", flag="true", age="34", number="3.14"))
with self.assertRaises(FunctionOutOfRetryError) as e:
_ = execute_mock_llm_func(llm, Response.from_response)

self.assertIn(
"String should have at least 1 character [type=string_too_short, input_value='', input_type=str]",
str(e.exception),
)

def test_custom_validation(self):
llm = MockLLM.from_response(format_response(text="Some text", flag="true", age="-5", number="3.14"))
with self.assertRaises(FunctionOutOfRetryError) as e:
_ = execute_mock_llm_func(llm, BadResponse.from_response)
_ = execute_mock_llm_func(llm, Response.from_response)

assert str(e.exception).strip().endswith("Unsupported type `Response` for field `complex_type`")
assert str(e.exception).strip().endswith("Age must be a positive number; got `-5`")

def test_correct(self):
llm = MockLLM.from_response(format_response(text="Some text", flag="true", age="34", number="3.14"))
Expand Down
Loading

0 comments on commit d84ffb8

Please sign in to comment.