diff --git a/council/llm/llm_answer.py b/council/llm/llm_answer.py
index 7ebda01c..c1414eeb 100644
--- a/council/llm/llm_answer.py
+++ b/council/llm/llm_answer.py
@@ -8,7 +8,9 @@
 
 
 class LLMParsingException(Exception):
-    pass
+    def __init__(self, message: str = "Your response is not correctly formatted.") -> None:
+        super().__init__(message)
+        self.message = message
 
 
 class llm_property(property):
diff --git a/council/llm/llm_base.py b/council/llm/llm_base.py
index d7cc3042..c1a57ccd 100644
--- a/council/llm/llm_base.py
+++ b/council/llm/llm_base.py
@@ -60,6 +60,10 @@ def __init__(
     def configuration(self) -> T_Configuration:
         return self._configuration
 
+    @property
+    def model_name(self) -> str:
+        return self.configuration.model_name()
+
     def post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
         """
         Sends a chat request to the language model.
diff --git a/council/llm/llm_exception.py b/council/llm/llm_exception.py
index 324ec4d4..11ecd87c 100644
--- a/council/llm/llm_exception.py
+++ b/council/llm/llm_exception.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Sequence
 
 
 class LLMException(Exception):
@@ -17,9 +17,8 @@ def __init__(self, message: str, llm_name: Optional[str]) -> None:
         Returns:
             None
         """
-        super().__init__(
-            f"llm:{llm_name}, message {message}" if llm_name is not None and len(llm_name) > 0 else message
-        )
+        self.message = f"llm:{llm_name}, message {message}" if llm_name is not None and len(llm_name) > 0 else message
+        super().__init__(self.message)
 
 
 class LLMCallTimeoutException(LLMException):
@@ -89,3 +88,18 @@ def __init__(self, token_count: int, limit: int, model: str, llm_name: Optional[
             None
         """
         super().__init__(f"token_count={token_count} is exceeding model {model} limit of {limit} tokens.", llm_name)
+
+
+class LLMOutOfRetriesException(LLMException):
+    """
+    Custom exception raised when the maximum number of retries is reached.
+    """
+
+    def __init__(
+        self, llm_name: Optional[str], retry_count: int, exceptions: Optional[Sequence[Exception]] = None
+    ) -> None:
+        """
+        Initializes an instance of LLMOutOfRetriesException, storing the exceptions raised across attempts.
+        """
+        super().__init__(f"Exceeded maximum retries after {retry_count} attempts", llm_name)
+        self.exceptions = list(exceptions) if exceptions is not None else []
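The exception changes above make failures inspectable programmatically. A minimal sketch of the new attribute surface, assuming this diff is applied; the `azure-gpt` name is an illustrative placeholder:

```python
# Sketch of the new exception attributes introduced above.
from council.llm import LLMParsingException
from council.llm.llm_exception import LLMException, LLMOutOfRetriesException

# LLMParsingException now carries a default, user-facing message.
assert LLMParsingException().message == "Your response is not correctly formatted."

# LLMException now exposes its formatted message; "azure-gpt" is a placeholder name.
assert LLMException("boom", "azure-gpt").message == "llm:azure-gpt, message boom"
assert LLMException("boom", None).message == "boom"

# LLMOutOfRetriesException keeps the underlying errors for later inspection.
err = LLMOutOfRetriesException(llm_name="azure-gpt", retry_count=3, exceptions=[ValueError("bad json")])
assert len(err.exceptions) == 1
```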
diff --git a/council/llm/llm_function.py b/council/llm/llm_function.py
new file mode 100644
index 00000000..78c5ffdc
--- /dev/null
+++ b/council/llm/llm_function.py
@@ -0,0 +1,96 @@
+from typing import Any, Callable, Generic, List, Optional, Sequence, TypeVar
+
+from council import LLMContext
+from council.llm import LLMBase, LLMMessage, LLMParsingException
+from council.llm.llm_middleware import LLMMiddlewareChain, LLMRequest
+
+T_Response = TypeVar("T_Response")
+
+LLMResponseParser = Callable[[str], T_Response]
+
+
+class LLMFunctionError(Exception):
+    """
+    Exception raised when an error occurs during the execution of an LLMFunction.
+    """
+
+    def __init__(self, message: str, retryable: bool = False):
+        """
+        Initialize the LLMFunctionError instance.
+        """
+        super().__init__(message)
+        self.message = message
+        self.retryable = retryable
+
+
+class FunctionOutOfRetryError(LLMFunctionError):
+    """
+    Exception raised when the maximum number of function execution retries is reached.
+    Stores all previous exceptions raised during retry attempts.
+    """
+
+    def __init__(self, retry_count: int, exceptions: Optional[Sequence[Exception]] = None):
+        """
+        Initialize the FunctionOutOfRetryError instance.
+
+        Args:
+            retry_count (int): The number of retries attempted.
+            exceptions (Optional[Sequence[Exception]]): Exceptions raised during retry attempts.
+        """
+        super().__init__(f"Exceeded maximum retries after {retry_count} attempts")
+        self.exceptions = exceptions if exceptions is not None else []
+
+    def __str__(self) -> str:
+        message = super().__str__()
+        if self.exceptions:
+            message += "\nPrevious exceptions:\n"
+            for i, exception in enumerate(self.exceptions, start=1):
+                message += f"{i}. {exception}\n"
+        return message
+
+
+class LLMFunction(Generic[T_Response]):
+    def __init__(
+        self, llm: LLMBase, response_parser: LLMResponseParser, system_message: str, max_retries: int = 3
+    ) -> None:
+        self._llm_middleware = LLMMiddlewareChain(llm)
+        self._system_message = LLMMessage.system_message(system_message)
+        self._response_parser = response_parser
+        self._max_retries = max_retries
+        self._context = LLMContext.empty()
+
+    def execute(self, user_message: str, **kwargs: Any) -> T_Response:
+        messages = [self._system_message, LLMMessage.user_message(user_message)]
+        new_messages: List[LLMMessage] = []
+        exceptions: List[Exception] = []
+
+        retry = 0
+        while retry <= self._max_retries:
+            messages = messages + new_messages
+            request = LLMRequest(context=self._context, messages=messages, **kwargs)
+            response = ""  # keep `response` defined for error handling even if the request itself raises
+            try:
+                llm_response = self._llm_middleware.execute(request)
+                if llm_response.result is not None:
+                    response = llm_response.result.first_choice
+                    return self._response_parser(response)
+            except LLMParsingException as e:
+                exceptions.append(e)
+                new_messages = self._handle_error(e, response, e.message)
+            except LLMFunctionError as e:
+                exceptions.append(e)
+                if not e.retryable:
+                    raise e
+                new_messages = self._handle_error(e, response, e.message)
+            except Exception as e:
+                exceptions.append(e)
+                new_messages = self._handle_error(e, response, f"Fix the following exception: `{e}`")
+
+            retry += 1
+
+        raise FunctionOutOfRetryError(self._max_retries, exceptions)
+
+    def _handle_error(self, e: Exception, response: str, user_message: str) -> List[LLMMessage]:
+        error = f"{e.__class__.__name__}: `{e}`"
+        self._context.logger.warning(f"Exception occurred: {error} for response {response}")
+        return [LLMMessage.assistant_message(response), LLMMessage.user_message(f"{user_message} Fix\n{error}")]
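To illustrate how LLMFunction drives the parse-and-retry loop, here is a hedged usage sketch; `MockLLM` comes from `council.mocks` (used by the unit tests below), while the parser and prompts are made up for the example:

```python
from council.llm import LLMParsingException
from council.llm.llm_function import FunctionOutOfRetryError, LLMFunction
from council.mocks import MockLLM

def parse_yes_no(response: str) -> bool:
    # Raising LLMParsingException makes LLMFunction feed the error back to the model and retry.
    answer = response.strip().lower()
    if answer not in ("yes", "no"):
        raise LLMParsingException("Answer with exactly `yes` or `no`.")
    return answer == "yes"

llm_func: LLMFunction[bool] = LLMFunction(MockLLM.from_response("yes"), parse_yes_no, "Answer `yes` or `no` only.")
try:
    print(llm_func.execute("Is SQL declarative?"))  # True
except FunctionOutOfRetryError as e:
    print(e)  # lists every exception raised across the retry attempts
```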
diff --git a/council/llm/llm_middleware.py b/council/llm/llm_middleware.py
new file mode 100644
index 00000000..ddee6995
--- /dev/null
+++ b/council/llm/llm_middleware.py
@@ -0,0 +1,121 @@
+from __future__ import annotations
+
+import time
+from typing import Any, Callable, List, Optional, Protocol, Sequence
+
+from council import LLMContext
+
+from .llm_base import LLMBase, LLMMessage, LLMResult
+from .llm_exception import LLMOutOfRetriesException
+
+
+class LLMRequest:
+    def __init__(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any):
+        self._context = context
+        self._messages = messages
+        self._kwargs = kwargs
+
+    @property
+    def context(self) -> LLMContext:
+        return self._context
+
+    @property
+    def messages(self) -> Sequence[LLMMessage]:
+        return self._messages
+
+    @property
+    def kwargs(self) -> Any:
+        return self._kwargs
+
+    @staticmethod
+    def default(messages: Sequence[LLMMessage], **kwargs: Any) -> LLMRequest:
+        return LLMRequest(LLMContext.empty(), messages, **kwargs)
+
+
+class LLMResponse:
+    def __init__(self, request: LLMRequest, result: Optional[LLMResult], duration: float):
+        self._request = request
+        self._result = result
+        self._duration = duration
+
+    @property
+    def result(self) -> Optional[LLMResult]:
+        return self._result
+
+    @property
+    def duration(self) -> float:
+        return self._duration
+
+    @staticmethod
+    def empty(request: LLMRequest) -> LLMResponse:
+        return LLMResponse(request, None, -1.0)
+
+
+ExecuteLLMRequest = Callable[[LLMRequest], LLMResponse]
+
+
+class LLMMiddleware(Protocol):
+    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse: ...
+
+
+class LLMMiddlewareChain:
+    def __init__(self, llm: LLMBase, middlewares: Optional[Sequence[LLMMiddleware]] = None) -> None:
+        self._llm = llm
+        self._middlewares: list[LLMMiddleware] = list(middlewares) if middlewares else []
+
+    def add_middleware(self, middleware: LLMMiddleware) -> None:
+        self._middlewares.append(middleware)
+
+    def execute(self, request: LLMRequest) -> LLMResponse:
+        def execute_request(r: LLMRequest) -> LLMResponse:
+            # use `r`, not the captured `request`, so middlewares may substitute the request
+            start = time.time()
+            result = self._llm.post_chat_request(r.context, r.messages, **r.kwargs)
+            return LLMResponse(r, result, time.time() - start)
+
+        handler: ExecuteLLMRequest = execute_request
+        for middleware in reversed(self._middlewares):
+            handler = self._wrap_middleware(middleware, handler)
+        return handler(request)
+
+    def _wrap_middleware(self, middleware: LLMMiddleware, handler: ExecuteLLMRequest) -> ExecuteLLMRequest:
+        def wrapped(request: LLMRequest) -> LLMResponse:
+            return middleware(self._llm, handler, request)
+
+        return wrapped
+
+
+class LLMLoggingMiddleware:
+    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
+        request.context.logger.info(
+            f"Sending request with {len(request.messages)} message(s) to {llm.model_name}"
+        )
+        response = execute(request)
+        if response.result is not None:
+            request.context.logger.info(f"Response: `{response.result.first_choice}` in {response.duration} seconds")
+        else:
+            request.context.logger.warning("No response")
+        return response
+
+
+class LLMRetryMiddleware:
+    def __init__(self, retries: int, delay: float, exception_to_check: Optional[type[Exception]] = None) -> None:
+        self._retries = retries
+        self._delay = delay
+        self._exception_to_check = exception_to_check if exception_to_check else Exception
+
+    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
+        attempt = 0
+        exceptions: List[Exception] = []
+        while attempt < self._retries:
+            try:
+                return execute(request)
+            except Exception as e:
+                if not isinstance(e, self._exception_to_check):
+                    raise
+                exceptions.append(e)
+                attempt += 1
+                if attempt >= self._retries:
+                    break
+                time.sleep(self._delay)
+
+        raise LLMOutOfRetriesException(llm_name=llm.model_name, retry_count=attempt, exceptions=exceptions)
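Because LLMMiddleware is a Protocol, any callable with the matching signature can join the chain. A sketch of a hypothetical duration-collecting middleware (not part of this diff; `MockLLM` comes from `council.mocks`):

```python
from typing import List

from council.llm import LLMBase, LLMMessage
from council.llm.llm_middleware import ExecuteLLMRequest, LLMMiddlewareChain, LLMRequest, LLMResponse
from council.mocks import MockLLM

class DurationCollectorMiddleware:
    """Hypothetical middleware recording how long each request took."""

    def __init__(self) -> None:
        self.durations: List[float] = []

    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        response = execute(request)  # delegate to the next handler in the chain
        self.durations.append(response.duration)
        return response

chain = LLMMiddlewareChain(MockLLM.from_response("USD"))
collector = DurationCollectorMiddleware()
chain.add_middleware(collector)
chain.execute(LLMRequest.default([LLMMessage.user_message("Name a currency")]))
print(collector.durations)  # e.g. [0.0001]
```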
diff --git a/tests/integration/llm/test_llm_function.py b/tests/integration/llm/test_llm_function.py
new file mode 100644
index 00000000..816a0f5b
--- /dev/null
+++ b/tests/integration/llm/test_llm_function.py
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+import json
+import unittest
+
+import dotenv
+
+from council import AzureLLM
+from council.llm import LLMParsingException
+from council.llm.llm_function import LLMFunction
+from council.utils import CodeParser
+
+SP = """
+You are a SQL expert solving the `Task` leveraging the database schema in the `DATASET` section.
+
+# Instructions
+- Assess whether the `Task` is reasonable and possible to solve given the database schema
+- The entire response must be inside a valid `json` code block as defined in the `Response formatting` section
+- Keep your explanation concise with only important details and assumptions, no excuses or other comments
+
+# Response formatting
+
+Your entire response must be inside the following `json` code block.
+The JSON response schema must contain the following keys: `solved`, `explanation` and `sql`.
+
+```json
+{
+    "solved": {Boolean, indicating whether the task is solved based on the provided database schema},
+    "explanation": {String, concise explanation of the solution if solved or reasoning if not solved},
+    "sql": {String, the sql query if the task is solved, otherwise empty}
+}
+```
+
+# DATASET - nyc_airbnb
+## Dataset Description
+The dataset is the New York City Airbnb Open Data which includes information on Airbnb listings in NYC for the year 2019.
+It provides data such as host id and name, geographical coordinates, room types, pricing, etc.
+
+## Tables
+### Table Name: NYC_2019
+
+#### Table Description
+Since 2008, guests and hosts have used Airbnb to expand on traveling possibilities and present a more unique, personalized way of experiencing the world. This table describes the listing activity and metrics in NYC, NY for 2019.
+Content
+This data file includes all needed information to find out more about hosts, geographical availability, and the necessary metrics to make predictions and draw conclusions.
+
+#### Columns
+For each column, the name, data type and description are given as follows: `{name}: {data type}: {description}`
+id: BIGINT: listing ID
+name: TEXT: name of the listing
+neighbourhood_group: TEXT: location
+neighbourhood: TEXT: area
+price: BIGINT: price in dollars
+"""
+
+U = "Price distribution by borough"
+
+
+class SQLResult:
+    def __init__(self, solved: bool, explanation: str, sql: str):
+        self.solved = solved
+        self.explanation = explanation
+        self.sql = sql
+
+    @staticmethod
+    def from_response(llm_response: str) -> SQLResult:
+        json_block = CodeParser.find_first("json", llm_response)
+        if json_block is None:
+            raise LLMParsingException("No json block found in response")
+        response = json.loads(json_block.code)
+        sql = response.get("sql")
+        if sql:  # treat an empty string as "no query" rather than a parsing failure
+            if "LIMIT" not in sql:
+                raise LLMParsingException("Generated SQL query should contain a LIMIT clause")
+            return SQLResult(response["solved"], response["explanation"], sql)
+        return SQLResult(False, "No SQL query generated", "")
+
+    def __str__(self):
+        if self.solved:
+            return f"Sql: {self.sql}\n\nExplanation: {self.explanation}"
+        return f"Not solved.\nExplanation: {self.explanation}"
+
+
+class TestLlmAzure(unittest.TestCase):
+    """requires an Azure LLM model deployed"""
+
+    def setUp(self) -> None:
+        dotenv.load_dotenv()
+        self.llm = AzureLLM.from_env()
+
+    def test_basic_prompt(self):
+        llm_func = LLMFunction(self.llm, SQLResult.from_response, SP)
+        sql_result = llm_func.execute(U)
+        self.assertIsInstance(sql_result, SQLResult)
+        print("", sql_result, sep="\n")
diff --git a/tests/unit/llm/test_llm_middleware.py b/tests/unit/llm/test_llm_middleware.py
new file mode 100644
index 00000000..15f3738e
--- /dev/null
+++ b/tests/unit/llm/test_llm_middleware.py
@@ -0,0 +1,52 @@
+import unittest
+
+from council.llm import LLMMessage, LLMException, LLMFallback
+from council.llm.llm_exception import LLMOutOfRetriesException, LLMCallTimeoutException
+from council.llm.llm_middleware import LLMRequest, LLMMiddlewareChain, LLMLoggingMiddleware, LLMRetryMiddleware
+from council.mocks import MockLLM, MockErrorLLM
+
+
+class TestLlmMiddleware(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self._llm = MockLLM.from_response("USD")
+
+    def test_with_log(self):
+        messages = [LLMMessage.user_message("Give me an example of a currency")]
+        request = LLMRequest.default(messages)
+
+        with_logs = LLMMiddlewareChain(self._llm)
+        with_logs.add_middleware(LLMLoggingMiddleware())
+        llm_response = with_logs.execute(request)
+        result = llm_response.result.first_choice
+        print(result)
+
+    def test_with_retry(self):
+        messages = [LLMMessage.user_message("Give me an example of a currency")]
+        request = LLMRequest.default(messages)
+
+        with_retry = LLMMiddlewareChain(MockErrorLLM())
+        with_retry.add_middleware(LLMLoggingMiddleware())
+        with_retry.add_middleware(LLMRetryMiddleware(retries=3, delay=1, exception_to_check=LLMException))
+        with self.assertRaises(LLMOutOfRetriesException):
+            _ = with_retry.execute(request)
+
+    def test_with_no_retry(self):
+        messages = [LLMMessage.user_message("Give me an example of a currency")]
+        request = LLMRequest.default(messages)
+
+        with_retry = LLMMiddlewareChain(MockErrorLLM())
+        with_retry.add_middleware(LLMLoggingMiddleware())
+        with_retry.add_middleware(LLMRetryMiddleware(retries=3, delay=1, exception_to_check=LLMCallTimeoutException))
+        with self.assertRaises(LLMException):
+            _ = with_retry.execute(request)
+
+    def test_with_retry_no_error(self):
+        messages = [LLMMessage.user_message("Give me an example of a currency")]
+        request = LLMRequest.default(messages)
+
+        with_retry = LLMMiddlewareChain(LLMFallback(MockErrorLLM(), MockLLM.from_response("USD")))
+        with_retry.add_middleware(LLMLoggingMiddleware())
+        with_retry.add_middleware(LLMRetryMiddleware(retries=3, delay=1, exception_to_check=LLMCallTimeoutException))
+        response = with_retry.execute(request)
+        self.assertEqual("USD", response.result.first_choice)
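End to end, the retry middleware and the enriched exception compose as follows; a sketch using the same mocks as the unit tests above:

```python
from council.llm import LLMException, LLMMessage
from council.llm.llm_exception import LLMOutOfRetriesException
from council.llm.llm_middleware import LLMMiddlewareChain, LLMRequest, LLMRetryMiddleware
from council.mocks import MockErrorLLM

chain = LLMMiddlewareChain(MockErrorLLM())  # MockErrorLLM always raises an LLMException
chain.add_middleware(LLMRetryMiddleware(retries=2, delay=0.1, exception_to_check=LLMException))
try:
    chain.execute(LLMRequest.default([LLMMessage.user_message("hello")]))
except LLMOutOfRetriesException as e:
    print(e.message)  # llm:<model>, message Exceeded maximum retries after 2 attempts
    for inner in e.exceptions:  # the two underlying failures, now stored on the exception
        print(f"- {inner}")
```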