* AddLLMFunction

* Update tests/integration/llm/test_llm_function.py

Co-authored-by: Nikolaiev Dmytro <[email protected]>

* Update tests/integration/llm/test_llm_function.py

Co-authored-by: Nikolaiev Dmytro <[email protected]>

* Fix review comments

---------

Co-authored-by: Nikolaiev Dmytro <[email protected]>
1 parent 0904645 · commit ed1fb99
Showing 7 changed files with 387 additions and 5 deletions.
@@ -0,0 +1,95 @@

```python
from typing import Any, Callable, Generic, List, Optional, Sequence, TypeVar

from council import LLMContext
from council.llm import LLMBase, LLMMessage, LLMParsingException
from council.llm.llm_middleware import LLMMiddlewareChain, LLMRequest

T_Response = TypeVar("T_Response")

LLMResponseParser = Callable[[str], T_Response]


class LLMFunctionError(Exception):
    """
    Exception raised when an error occurs during the execution of a function.
    """

    def __init__(self, message: str, retryable: bool = False):
        """
        Initialize the FunctionError instance.
        """
        super().__init__(message)
        self.message = message
        self.retryable = retryable


class FunctionOutOfRetryError(LLMFunctionError):
    """
    Exception raised when the maximum number of function execution retries is reached.
    Stores all previous exceptions raised during retry attempts.
    """

    def __init__(self, retry_count: int, exceptions: Optional[Sequence[Exception]] = None):
        """
        Initialize the FunctionOutOfRetryException instance.
        Args:
            retry_count (int): The number of retries attempted.
            exceptions (List[Exception]): List of exceptions raised during retry attempts.
        """
        super().__init__(f"Exceeded maximum retries after {retry_count} attempts")
        self.exceptions = exceptions if exceptions is not None else []

    def __str__(self) -> str:
        message = super().__str__()
        if self.exceptions:
            message += "\nPrevious exceptions:\n"
            for i, exception in enumerate(self.exceptions, start=1):
                message += f"{i}. {exception}\n"
        return message


class LLMFunction(Generic[T_Response]):
    def __init__(
        self, llm: LLMBase, response_parser: LLMResponseParser, system_message: str, max_retries: int = 3
    ) -> None:
        self._llm_middleware = LLMMiddlewareChain(llm)
        self._system_message = LLMMessage.system_message(system_message)
        self._response_parser = response_parser
        self._max_retries = max_retries
        self._context = LLMContext.empty()

    def execute(self, user_message: str, **kwargs: Any) -> T_Response:
        messages = [self._system_message, LLMMessage.user_message(user_message)]
        new_messages: List[LLMMessage] = []
        exceptions: List[Exception] = []

        retry = 0
        while retry <= self._max_retries:
            messages = messages + new_messages
            request = LLMRequest(context=self._context, messages=messages, **kwargs)
            try:
                llm_response = self._llm_middleware.execute(request)
                if llm_response.result is not None:
                    response = llm_response.result.first_choice
                    return self._response_parser(response)
            except LLMParsingException as e:
                exceptions.append(e)
                new_messages = self._handle_error(e, response, e.message)
            except LLMFunctionError as e:
                exceptions.append(e)
                if not e.retryable:
                    raise e
                new_messages = self._handle_error(e, response, e.message)
            except Exception as e:
                exceptions.append(e)
                new_messages = self._handle_error(e, response, f"Fix the following exception: `{e}`")

            retry += 1

        raise FunctionOutOfRetryError(self._max_retries, exceptions)

    def _handle_error(self, e: Exception, response: str, user_message: str) -> List[LLMMessage]:
        error = f"{e.__class__.__name__}: `{e}`"
        self._context.logger.warning(f"Exception occurred: {error} for response {response}")
        return [LLMMessage.assistant_message(response), LLMMessage.user_message(f"{user_message} Fix\n{error}")]
```
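For orientation, here is a minimal usage sketch of the new `LLMFunction` (not part of the commit). It mirrors the integration test further below in loading an Azure model from environment variables; `parse_answer`, the prompt strings, and the expected answer are illustrative placeholders.

```python
import dotenv

from council import AzureLLM
from council.llm import LLMParsingException
from council.llm.llm_function import LLMFunction


def parse_answer(response: str) -> int:
    # Raising LLMParsingException feeds the faulty response and this message
    # back to the model, so LLMFunction can retry with a correction prompt.
    try:
        return int(response.strip())
    except ValueError as e:
        raise LLMParsingException(f"Expected a single integer, got: {response!r}") from e


dotenv.load_dotenv()
llm = AzureLLM.from_env()
llm_func = LLMFunction(llm, parse_answer, system_message="Reply with a single integer and nothing else.")
answer = llm_func.execute(user_message="How many days are in a week?")
print(answer)  # expected: 7, assuming the model follows the instruction
```

The parser is the contract here: any `LLMParsingException` (or retryable `LLMFunctionError`) it raises is appended to the conversation and retried up to `max_retries` times; if every attempt fails, the accumulated errors surface as `FunctionOutOfRetryError`.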
@@ -0,0 +1,121 @@

```python
from __future__ import annotations

import time
from typing import Any, Callable, List, Optional, Protocol, Sequence

from council import LLMContext

from .llm_base import LLMBase, LLMMessage, LLMResult
from .llm_exception import LLMOutOfRetriesException


class LLMRequest:
    def __init__(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any):
        self._context = context
        self._messages = messages
        self._kwargs = kwargs

    @property
    def context(self) -> LLMContext:
        return self._context

    @property
    def messages(self) -> Sequence[LLMMessage]:
        return self._messages

    @property
    def kwargs(self) -> Any:
        return self._kwargs

    @staticmethod
    def default(messages: Sequence[LLMMessage], **kwargs: Any) -> LLMRequest:
        return LLMRequest(LLMContext.empty(), messages, **kwargs)


class LLMResponse:
    def __init__(self, request: LLMRequest, result: Optional[LLMResult], duration: float):
        self._request = request
        self._result = result
        self._duration = duration

    @property
    def result(self) -> Optional[LLMResult]:
        return self._result

    @property
    def duration(self) -> float:
        return self._duration

    @staticmethod
    def empty(request: LLMRequest) -> LLMResponse:
        return LLMResponse(request, None, -1.0)


ExecuteLLMRequest = Callable[[LLMRequest], LLMResponse]


class LLMMiddleware(Protocol):
    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse: ...


class LLMMiddlewareChain:
    def __init__(self, llm: LLMBase, middlewares: Optional[Sequence[LLMMiddleware]] = None) -> None:
        self._llm = llm
        self._middlewares: list[LLMMiddleware] = list(middlewares) if middlewares else []

    def add_middleware(self, middleware: LLMMiddleware) -> None:
        self._middlewares.append(middleware)

    def execute(self, request: LLMRequest) -> LLMResponse:
        def execute_request(r: LLMRequest) -> LLMResponse:
            start = time.time()
            result = self._llm.post_chat_request(r.context, request.messages, **r.kwargs)
            return LLMResponse(request, result, time.time() - start)

        handler: ExecuteLLMRequest = execute_request
        for middleware in reversed(self._middlewares):
            handler = self._wrap_middleware(middleware, handler)
        return handler(request)

    def _wrap_middleware(self, middleware: LLMMiddleware, handler: ExecuteLLMRequest) -> ExecuteLLMRequest:
        def wrapped(request: LLMRequest) -> LLMResponse:
            return middleware(self._llm, handler, request)

        return wrapped


class LLMLoggingMiddleware:
    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        request.context.logger.info(
            f"Sending request with {len(request.messages)} message(s) to {llm.configuration.model_name()}"
        )
        response = execute(request)
        if response.result is not None:
            request.context.logger.info(f"Response: `{response.result.first_choice}` in {response.duration} seconds")
        else:
            request.context.logger.warning("No response")
        return response


class LLMRetryMiddleware:
    def __init__(self, retries: int, delay: float, exception_to_check: Optional[type[Exception]] = None) -> None:
        self._retries = retries
        self._delay = delay
        self._exception_to_check = exception_to_check if exception_to_check else Exception

    def __call__(self, llm: LLMBase, execute: ExecuteLLMRequest, request: LLMRequest) -> LLMResponse:
        attempt = 0
        exceptions: List[Exception] = []
        while attempt < self._retries:
            try:
                return execute(request)
            except Exception as e:
                if not isinstance(e, self._exception_to_check):
                    raise
                exceptions.append(e)
                attempt += 1
                if attempt >= self._retries:
                    break
                time.sleep(self._delay)

        raise LLMOutOfRetriesException(llm_name=llm.model_name, retry_count=attempt, exceptions=exceptions)
```
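The chain can also be driven directly, without `LLMFunction`. A minimal sketch (not part of the commit; the Azure setup and message text are illustrative): with the middleware list ordered `[logging, retry]`, logging becomes the outermost wrapper and the retry middleware sits between it and the actual LLM call.

```python
import dotenv

from council import AzureLLM
from council.llm import LLMMessage
from council.llm.llm_middleware import (
    LLMLoggingMiddleware,
    LLMMiddlewareChain,
    LLMRequest,
    LLMRetryMiddleware,
)

dotenv.load_dotenv()
llm = AzureLLM.from_env()

chain = LLMMiddlewareChain(llm, middlewares=[LLMLoggingMiddleware()])
chain.add_middleware(LLMRetryMiddleware(retries=3, delay=2.0))

# LLMRequest.default attaches an empty LLMContext, for standalone use outside an agent.
request = LLMRequest.default([LLMMessage.user_message("Say hello in one word.")])
response = chain.execute(request)
if response.result is not None:
    print(response.result.first_choice)
```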
@@ -0,0 +1,95 @@

````python
from __future__ import annotations

import json
import unittest

import dotenv

from council import AzureLLM
from council.llm import LLMParsingException
from council.llm.llm_function import LLMFunction
from council.utils import CodeParser

SP = """
You are a sql expert solving the `Task` leveraging the database schema in the `DATASET` section.
# Instructions
- Assess whether the `Task` is reasonable and possible to solve given the database schema
- The entire response must be inside a valid `json` code block as defined in the `Response formatting` section
- Keep your explanation concise with only important details and assumptions, no excuse or other comment
# Response formatting
Your entire response must be inside the following `json` code block:
The JSON response schema must contain the following keys: `solved`, `explanation` and `sql`.
```json
{
"solved": {Boolean, indicating whether the task is solved based on the provided database schema},
"explanation": {String, concise explanation of the solution if solved or reasoning if not solved},
"sql": {String, the sql query if the task is solved, otherwise empty}
}
```
# DATASET - nyc_airbnb
## Dataset Description
The dataset is the New York City Airbnb Open Data which includes information on Airbnb listings in NYC for the year 2019.
It provides data such as host id and name, geographical coordinates, room types, pricing, etc.
## Tables
### Table Name: NYC_2019
#### Table Description
Since 2008, guests and hosts have used Airbnb to expand on traveling possibilities and present more unique, personalized way of experiencing the world. This table describes the listing activity and metrics in NYC, NY for 2019.
Content
This data file includes all needed information to find out more about hosts, geographical availability, necessary metrics to make predictions and draw conclusions.
#### Columns
For each column, the name, data type and description are given as follow : {name}: {data type}: {description}`
id: BIGINT: listing ID
name: TEXT: name of the listing
neighbourhood_group: TEXT: location
neighbourhood: TEXT: area
price: BIGINT: price in dollars
"""

U = "Price distribution by borough"


class SQLResult:
    def __init__(self, solved: bool, explanation: str, sql: str):
        self.solved = solved
        self.explanation = explanation
        self.sql = sql

    @staticmethod
    def from_response(llm_response: str) -> SQLResult:
        json_bloc = CodeParser.find_first("json", llm_response)
        if json_bloc is None:
            raise LLMParsingException("No json block found in response")
        response = json.loads(json_bloc.code)
        sql = response.get("sql")
        if sql is not None:
            if "LIMIT" not in sql:
                raise LLMParsingException("Generated SQL query should contain a LIMIT clause")
            return SQLResult(response["solved"], response["explanation"], sql)
        return SQLResult(False, "No SQL query generated", "")

    def __str__(self):
        if self.solved:
            return f"Sql: {self.sql}\n\nExplanation: {self.explanation}"
        return f"Not solved.\nExplanation: {self.explanation}"


class TestLlmAzure(unittest.TestCase):
    """requires an Azure LLM model deployed"""

    def setUp(self) -> None:
        dotenv.load_dotenv()
        self.llm = AzureLLM.from_env()

    def test_basic_prompt(self):
        llm_func = LLMFunction(self.llm, SQLResult.from_response, SP)
        sql_result = llm_func.execute(U)
        self.assertIsInstance(sql_result, SQLResult)
        print("", sql_result, sep="\n")
````
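A possible companion test for the failure path (a sketch, not part of the commit, reusing the names defined in the test module above): a parser that always raises `LLMParsingException` exhausts the retry loop, and `LLMFunction` then raises `FunctionOutOfRetryError` carrying the intermediate exceptions.

```python
from council.llm.llm_function import FunctionOutOfRetryError


def always_fail(response: str) -> SQLResult:
    # Force the correction/retry loop regardless of what the model returns.
    raise LLMParsingException("forced failure to exercise the retry loop")


# Would live inside TestLlmAzure:
def test_out_of_retries(self) -> None:
    llm_func = LLMFunction(self.llm, always_fail, SP, max_retries=0)
    with self.assertRaises(FunctionOutOfRetryError):
        llm_func.execute(U)
```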