Add LLMFunction (#147)
* Add LLMFunction

* Update tests/integration/llm/test_llm_function.py

Co-authored-by: Nikolaiev Dmytro <[email protected]>

* Update tests/integration/llm/test_llm_function.py

Co-authored-by: Nikolaiev Dmytro <[email protected]>

* Fix review comments

---------

Co-authored-by: Nikolaiev Dmytro <[email protected]>
aflament and Winston-503 authored Jun 21, 2024
1 parent 0904645 commit ed1fb99
Showing 7 changed files with 387 additions and 5 deletions.
4 changes: 3 additions & 1 deletion council/llm/llm_answer.py
@@ -8,7 +8,9 @@
 
 
 class LLMParsingException(Exception):
-    pass
+    def __init__(self, message: str = "Your response is not correctly formatted.") -> None:
+        super().__init__(message)
+        self.message = message
 
 
 class llm_property(property):
4 changes: 4 additions & 0 deletions council/llm/llm_base.py
@@ -60,6 +60,10 @@ def __init__(
     def configuration(self) -> T_Configuration:
         return self._configuration
 
+    @property
+    def model_name(self) -> str:
+        return self.configuration.model_name()
+
     def post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
         """
         Sends a chat request to the language model.
21 changes: 17 additions & 4 deletions council/llm/llm_exception.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Sequence
 
 
 class LLMException(Exception):
@@ -17,9 +17,8 @@ def __init__(self, message: str, llm_name: Optional[str]) -> None:
         Returns:
             None
         """
-        super().__init__(
-            f"llm:{llm_name}, message {message}" if llm_name is not None and len(llm_name) > 0 else message
-        )
+        self.message = f"llm:{llm_name}, message {message}" if llm_name is not None and len(llm_name) > 0 else message
+        super().__init__(self.message)
 
 
 class LLMCallTimeoutException(LLMException):
@@ -89,3 +88,17 @@ def __init__(self, token_count: int, limit: int, model: str, llm_name: Optional[
             None
         """
         super().__init__(f"token_count={token_count} is exceeding model {model} limit of {limit} tokens.", llm_name)
+
+
+class LLMOutOfRetriesException(LLMException):
+    """
+    Custom exception raised when the maximum number of retries is reached.
+    """
+
+    def __init__(
+        self, llm_name: Optional[str], retry_count: int, exceptions: Optional[Sequence[Exception]] = None
+    ) -> None:
+        """
+        Initializes an instance of LLMOutOfRetriesException.
+        """
+        super().__init__(f"Exceeded maximum retries after {retry_count} attempts", llm_name)
95 changes: 95 additions & 0 deletions council/llm/llm_function.py
@@ -0,0 +1,95 @@
from typing import Any, Callable, Generic, List, Optional, Sequence, TypeVar

from council import LLMContext
from council.llm import LLMBase, LLMMessage, LLMParsingException
from council.llm.llm_middleware import LLMMiddlewareChain, LLMRequest

T_Response = TypeVar("T_Response")

LLMResponseParser = Callable[[str], T_Response]


class LLMFunctionError(Exception):
    """
    Exception raised when an error occurs during the execution of a function.
    """

    def __init__(self, message: str, retryable: bool = False):
        """
        Initialize the FunctionError instance.
        """
        super().__init__(message)
        self.message = message
        self.retryable = retryable


class FunctionOutOfRetryError(LLMFunctionError):
    """
    Exception raised when the maximum number of function execution retries is reached.
    Stores all previous exceptions raised during retry attempts.
    """

    def __init__(self, retry_count: int, exceptions: Optional[Sequence[Exception]] = None):
        """
        Initialize the FunctionOutOfRetryException instance.
        Args:
            retry_count (int): The number of retries attempted.
            exceptions (List[Exception]): List of exceptions raised during retry attempts.
        """
        super().__init__(f"Exceeded maximum retries after {retry_count} attempts")
        self.exceptions = exceptions if exceptions is not None else []

    def __str__(self) -> str:
        message = super().__str__()
        if self.exceptions:
            message += "\nPrevious exceptions:\n"
            for i, exception in enumerate(self.exceptions, start=1):
                message += f"{i}. {exception}\n"
        return message


class LLMFunction(Generic[T_Response]):
    def __init__(
        self, llm: LLMBase, response_parser: LLMResponseParser, system_message: str, max_retries: int = 3
    ) -> None:
        self._llm_middleware = LLMMiddlewareChain(llm)
        self._system_message = LLMMessage.system_message(system_message)
        self._response_parser = response_parser
        self._max_retries = max_retries
        self._context = LLMContext.empty()

    def execute(self, user_message: str, **kwargs: Any) -> T_Response:
        messages = [self._system_message, LLMMessage.user_message(user_message)]
        new_messages: List[LLMMessage] = []
        exceptions: List[Exception] = []

        retry = 0
        while retry <= self._max_retries:
            messages = messages + new_messages
            request = LLMRequest(context=self._context, messages=messages, **kwargs)
            try:
                llm_response = self._llm_middleware.execute(request)
                if llm_response.result is not None:
                    response = llm_response.result.first_choice
                    return self._response_parser(response)
            except LLMParsingException as e:
                exceptions.append(e)
                new_messages = self._handle_error(e, response, e.message)
            except LLMFunctionError as e:
                exceptions.append(e)
                if not e.retryable:
                    raise e
                new_messages = self._handle_error(e, response, e.message)
            except Exception as e:
                exceptions.append(e)
                new_messages = self._handle_error(e, response, f"Fix the following exception: `{e}`")

            retry += 1

        raise FunctionOutOfRetryError(self._max_retries, exceptions)

    def _handle_error(self, e: Exception, response: str, user_message: str) -> List[LLMMessage]:
        error = f"{e.__class__.__name__}: `{e}`"
        self._context.logger.warning(f"Exception occurred: {error} for response {response}")
        return [LLMMessage.assistant_message(response), LLMMessage.user_message(f"{user_message} Fix\n{error}")]
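
A minimal usage sketch of the new `LLMFunction` (the parser and prompts below are illustrative, not part of this commit; the committed integration test further down shows the SQL use case):

```python
from council import AzureLLM
from council.llm import LLMParsingException
from council.llm.llm_function import LLMFunction


def parse_yes_no(response: str) -> bool:
    # Parser contract: return the typed result, or raise LLMParsingException
    # so that LLMFunction appends a corrective message and retries.
    text = response.strip().lower()
    if text not in ("yes", "no"):
        raise LLMParsingException("Answer with exactly `yes` or `no`.")
    return text == "yes"


llm = AzureLLM.from_env()  # assumes Azure credentials are available in the environment
llm_func = LLMFunction(llm, parse_yes_no, system_message="Answer with `yes` or `no` only.", max_retries=2)
answer = llm_func.execute("Is Brooklyn one of the neighbourhood groups in the NYC_2019 table?")
```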
121 changes: 121 additions & 0 deletions council/llm/llm_middleware.py
@@ -0,0 +1,121 @@
from __future__ import annotations

import time
from typing import Any, Callable, List, Optional, Protocol, Sequence

from council import LLMContext

from .llm_base import LLMBase, LLMMessage, LLMResult
from .llm_exception import LLMOutOfRetriesException


class LLMRequest:
    def __init__(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any):
        self._context = context
        self._messages = messages
        self._kwargs = kwargs

    @property
    def context(self) -> LLMContext:
        return self._context

    @property
    def messages(self) -> Sequence[LLMMessage]:
        return self._messages

    @property
    def kwargs(self) -> Any:
        return self._kwargs

    @staticmethod
    def default(messages: Sequence[LLMMessage], **kwargs: Any) -> LLMRequest:
        return LLMRequest(LLMContext.empty(), messages, **kwargs)


class LLMResponse:
    def __init__(self, request: LLMRequest, result: Optional[LLMResult], duration: float):
        self._request = request
        self._result = result
        self._duration = duration

    @property
    def result(self) -> Optional[LLMResult]:
        return self._result

    @property
    def duration(self) -> float:
        return self._duration

    @staticmethod
    def empty(request: LLMRequest) -> LLMResponse:
        return LLMResponse(request, None, -1.0)


ExecuteLLMRequest = Callable[[LLMRequest], LLMResponse]


class LLMMiddleware(Protocol):
    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse: ...


class LLMMiddlewareChain:
    def __init__(self, llm: LLMBase, middlewares: Optional[Sequence[LLMMiddleware]] = None) -> None:
        self._llm = llm
        self._middlewares: list[LLMMiddleware] = list(middlewares) if middlewares else []

    def add_middleware(self, middleware: LLMMiddleware) -> None:
        self._middlewares.append(middleware)

    def execute(self, request: LLMRequest) -> LLMResponse:
        def execute_request(r: LLMRequest) -> LLMResponse:
            start = time.time()
            result = self._llm.post_chat_request(r.context, request.messages, **r.kwargs)
            return LLMResponse(request, result, time.time() - start)

        handler: ExecuteLLMRequest = execute_request
        for middleware in reversed(self._middlewares):
            handler = self._wrap_middleware(middleware, handler)
        return handler(request)

    def _wrap_middleware(self, middleware: LLMMiddleware, handler: ExecuteLLMRequest) -> ExecuteLLMRequest:
        def wrapped(request: LLMRequest) -> LLMResponse:
            return middleware(self._llm, handler, request)

        return wrapped


class LLMLoggingMiddleware:
    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        request.context.logger.info(
            f"Sending request with {len(request.messages)} message(s) to {llm.configuration.model_name()}"
        )
        response = execute(request)
        if response.result is not None:
            request.context.logger.info(f"Response: `{response.result.first_choice}` in {response.duration} seconds")
        else:
            request.context.logger.warning("No response")
        return response


class LLMRetryMiddleware:
    def __init__(self, retries: int, delay: float, exception_to_check: Optional[type[Exception]] = None) -> None:
        self._retries = retries
        self._delay = delay
        self._exception_to_check = exception_to_check if exception_to_check else Exception

    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        attempt = 0
        exceptions: List[Exception] = []
        while attempt < self._retries:
            try:
                return execute(request)
            except Exception as e:
                if not isinstance(e, self._exception_to_check):
                    raise
                exceptions.append(e)
                attempt += 1
                if attempt >= self._retries:
                    break
                time.sleep(self._delay)

        raise LLMOutOfRetriesException(llm_name=llm.model_name, retry_count=attempt, exceptions=exceptions)
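
The middleware chain can also be composed directly, outside of `LLMFunction`; a hedged sketch assuming an existing `LLMBase` instance named `llm` (illustrative, not part of the commit):

```python
from council.llm import LLMMessage
from council.llm.llm_middleware import (
    LLMLoggingMiddleware,
    LLMMiddlewareChain,
    LLMRequest,
    LLMRetryMiddleware,
)

# `llm` is assumed to be any LLMBase instance, e.g. AzureLLM.from_env()
chain = LLMMiddlewareChain(llm)
chain.add_middleware(LLMLoggingMiddleware())
chain.add_middleware(LLMRetryMiddleware(retries=3, delay=2.0))

request = LLMRequest.default([LLMMessage.user_message("Summarize the NYC_2019 table in one sentence.")])
response = chain.execute(request)
if response.result is not None:
    print(response.result.first_choice, f"({response.duration:.2f}s)")
```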
95 changes: 95 additions & 0 deletions tests/integration/llm/test_llm_function.py
@@ -0,0 +1,95 @@
from __future__ import annotations

import json
import unittest

import dotenv

from council import AzureLLM
from council.llm import LLMParsingException
from council.llm.llm_function import LLMFunction
from council.utils import CodeParser

SP = """
You are a sql expert solving the `Task` leveraging the database schema in the `DATASET` section.
# Instructions
- Assess whether the `Task` is reasonable and possible to solve given the database schema
- The entire response must be inside a valid `json` code block as defined in the `Response formatting` section
- Keep your explanation concise with only important details and assumptions, no excuse or other comment
# Response formatting
Your entire response must be inside the following `json` code block:
The JSON response schema must contain the following keys: `solved`, `explanation` and `sql`.
```json
{
"solved": {Boolean, indicating whether the task is solved based on the provided database schema},
"explanation": {String, concise explanation of the solution if solved or reasoning if not solved},
"sql": {String, the sql query if the task is solved, otherwise empty}
}
```
# DATASET - nyc_airbnb
## Dataset Description
The dataset is the New York City Airbnb Open Data which includes information on Airbnb listings in NYC for the year 2019.
It provides data such as host id and name, geographical coordinates, room types, pricing, etc.
## Tables
### Table Name: NYC_2019
#### Table Description
Since 2008, guests and hosts have used Airbnb to expand on traveling possibilities and present more unique, personalized way of experiencing the world. This table describes the listing activity and metrics in NYC, NY for 2019.
Content
This data file includes all needed information to find out more about hosts, geographical availability, necessary metrics to make predictions and draw conclusions.
#### Columns
For each column, the name, data type and description are given as follow : {name}: {data type}: {description}`
id: BIGINT: listing ID
name: TEXT: name of the listing
neighbourhood_group: TEXT: location
neighbourhood: TEXT: area
price: BIGINT: price in dollars
"""

U = "Price distribution by borough"


class SQLResult:
    def __init__(self, solved: bool, explanation: str, sql: str):
        self.solved = solved
        self.explanation = explanation
        self.sql = sql

    @staticmethod
    def from_response(llm_response: str) -> SQLResult:
        json_bloc = CodeParser.find_first("json", llm_response)
        if json_bloc is None:
            raise LLMParsingException("No json block found in response")
        response = json.loads(json_bloc.code)
        sql = response.get("sql")
        if sql is not None:
            if "LIMIT" not in sql:
                raise LLMParsingException("Generated SQL query should contain a LIMIT clause")
            return SQLResult(response["solved"], response["explanation"], sql)
        return SQLResult(False, "No SQL query generated", "")

    def __str__(self):
        if self.solved:
            return f"Sql: {self.sql}\n\nExplanation: {self.explanation}"
        return f"Not solved.\nExplanation: {self.explanation}"


class TestLlmAzure(unittest.TestCase):
    """requires an Azure LLM model deployed"""

    def setUp(self) -> None:
        dotenv.load_dotenv()
        self.llm = AzureLLM.from_env()

    def test_basic_prompt(self):
        llm_func = LLMFunction(self.llm, SQLResult.from_response, SP)
        sql_result = llm_func.execute(U)
        self.assertIsInstance(sql_result, SQLResult)
        print("", sql_result, sep="\n")