Enhance LLMFunction (#153)
* Enhance LLMFunction

- Add an optional list of messages to the execute function
- Change the response parser input from str to LLMResponse

* Update llm_function.py

* Add type to __init__

* Fixes
aflament authored Jul 4, 2024
1 parent d8e7006 commit 711c94f
Showing 4 changed files with 52 additions and 31 deletions.
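In short: execute now accepts either a plain string or an LLMMessage plus an optional list of extra messages, and the response parser receives the full LLMResponse instead of a bare string. A minimal usage sketch of the updated API (the prompt text and parser are illustrative, not part of this commit):

from council import AzureLLM
from council.llm import LLMMessage, LLMResponse
from council.llm.llm_function import LLMFunction

# The parser now receives the full LLMResponse rather than a plain string.
def parse(response: LLMResponse) -> str:
    return response.result.first_choice if response.result else ""

llm = AzureLLM.from_env()  # any configured LLM instance works here
llm_func = LLMFunction(llm, parse, "You are a helpful assistant.")

# user_message may be a str or an LLMMessage; extra messages are appended after it.
answer = llm_func.execute(
    LLMMessage.user_message("Summarize the repository."),
    messages=[LLMMessage.user_message("Answer in one sentence.")],
)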
council/llm/__init__.py: 6 changes (4 additions & 2 deletions)
@@ -9,10 +9,12 @@
 from .llm_exception import LLMException, LLMCallException, LLMCallTimeoutException, LLMTokenLimitException
 from .llm_message import LLMMessageRole, LLMMessage, LLMessageTokenCounterBase
 from .llm_base import LLMBase, LLMResult, LLMConfigurationBase
-from .monitored_llm import MonitoredLLM
-from .chat_gpt_configuration import ChatGPTConfigurationBase
 from .llm_fallback import LLMFallback
+from .llm_middleware import LLMMiddleware, LLMRequest, LLMResponse
 from .llm_function import LLMFunction, LLMFunctionError, FunctionOutOfRetryError
+from .monitored_llm import MonitoredLLM
+
+from .chat_gpt_configuration import ChatGPTConfigurationBase
 from .openai_chat_completions_llm import OpenAIChatCompletionsModel
 from .openai_token_counter import OpenAITokenCounter
 
council/llm/llm_function.py: 47 changes (29 additions & 18 deletions)
@@ -1,12 +1,14 @@
-from typing import Any, Callable, Generic, List, Optional, Sequence, TypeVar
+from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, TypeVar, Union
 
-from council import LLMContext
-from council.llm import LLMBase, LLMMessage, LLMParsingException
-from council.llm.llm_middleware import LLMMiddlewareChain, LLMRequest
+from council.contexts import LLMContext
+
+from .llm_answer import LLMParsingException
+from .llm_base import LLMBase, LLMMessage
+from .llm_middleware import LLMMiddlewareChain, LLMRequest, LLMResponse
 
 T_Response = TypeVar("T_Response")
 
-LLMResponseParser = Callable[[str], T_Response]
+LLMResponseParser = Callable[[LLMResponse], T_Response]
 
 
 class LLMFunctionError(Exception):
@@ -59,37 +61,46 @@ def __init__(
         self._max_retries = max_retries
         self._context = LLMContext.empty()
 
-    def execute(self, user_message: str, **kwargs: Any) -> T_Response:
-        messages = [self._system_message, LLMMessage.user_message(user_message)]
+    def execute(
+        self, user_message: Union[str, LLMMessage], messages: Optional[Iterable[LLMMessage]] = None, **kwargs: Any
+    ) -> T_Response:
+        um = user_message if isinstance(user_message, LLMMessage) else LLMMessage.user_message(user_message)
+        llm_messages = [self._system_message, um]
+        if messages:
+            llm_messages = llm_messages + list(messages)
         new_messages: List[LLMMessage] = []
         exceptions: List[Exception] = []
 
         retry = 0
         while retry <= self._max_retries:
-            messages = messages + new_messages
-            request = LLMRequest(context=self._context, messages=messages, **kwargs)
+            llm_messages = llm_messages + new_messages
+            request = LLMRequest(context=self._context, messages=llm_messages, **kwargs)
             try:
                 llm_response = self._llm_middleware.execute(request)
                 if llm_response.result is not None:
-                    response = llm_response.result.first_choice
-                    return self._response_parser(response)
+                    return self._response_parser(llm_response)
             except LLMParsingException as e:
                 exceptions.append(e)
-                new_messages = self._handle_error(e, response, e.message)
+                new_messages = self._handle_error(e, llm_response, e.message)
             except LLMFunctionError as e:
                 exceptions.append(e)
                 if not e.retryable:
                     raise e
-                new_messages = self._handle_error(e, response, e.message)
+                new_messages = self._handle_error(e, llm_response, e.message)
             except Exception as e:
                 exceptions.append(e)
-                new_messages = self._handle_error(e, response, f"Fix the following exception: `{e}`")
+                new_messages = self._handle_error(e, llm_response, f"Fix the following exception: `{e}`")
 
             retry += 1
 
         raise FunctionOutOfRetryError(self._max_retries, exceptions)
 
-    def _handle_error(self, e: Exception, response: str, user_message: str) -> List[LLMMessage]:
+    def _handle_error(self, e: Exception, response: LLMResponse, user_message: str) -> List[LLMMessage]:
         error = f"{e.__class__.__name__}: `{e}`"
-        self._context.logger.warning(f"Exception occurred: {error} for response {response}")
-        return [LLMMessage.assistant_message(response), LLMMessage.user_message(f"{user_message} Fix\n{error}")]
+        if response.result is None:
+            self._context.logger.warning(f"Exception occurred: {error} without response.")
+            return [LLMMessage.assistant_message("No response"), LLMMessage.user_message("Please retry.")]
+
+        first_choice = response.result.first_choice
+        error += f"\nResponse: {first_choice}"
+        self._context.logger.warning(f"Exception occurred: {error} for response {first_choice}")
+        return [LLMMessage.assistant_message(first_choice), LLMMessage.user_message(f"{user_message} Fix\n{error}")]
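Because LLMResponseParser is now Callable[[LLMResponse], T_Response], a parser can inspect the whole response object and raise LLMParsingException to route the error back through the retry loop above. A minimal parser sketch (the yes/no contract is illustrative, not from this commit):

from council.llm import LLMParsingException, LLMResponse

def parse_yes_no(response: LLMResponse) -> bool:
    # result may be None when the middleware produced no result
    text = response.result.first_choice if response.result else ""
    normalized = text.strip().lower()
    if normalized not in ("yes", "no"):
        # raising LLMParsingException feeds the error back to the model for another attempt
        raise LLMParsingException(f"Expected 'yes' or 'no', got: {text!r}")
    return normalized == "yes"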
council/llm/llm_middleware.py: 2 changes (1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 import time
 from typing import Any, Callable, List, Optional, Protocol, Sequence
 
-from council import LLMContext
+from council.contexts import LLMContext
 
 from .llm_base import LLMBase, LLMMessage, LLMResult
 from .llm_exception import LLMOutOfRetriesException
tests/integration/llm/test_llm_function.py: 28 changes (18 additions & 10 deletions)
@@ -6,11 +6,11 @@
 import dotenv
 
 from council import AzureLLM
-from council.llm import LLMParsingException
+from council.llm import LLMParsingException, LLMResponse, LLMMessage
 from council.llm.llm_function import LLMFunction
 from council.utils import CodeParser
 
-SP = """
+SYSTEM_PROMPT = """
 You are a sql expert solving the `Task` leveraging the database schema in the `DATASET` section.
 
 # Instructions
@@ -53,7 +53,7 @@
     price: BIGINT: price in dollars
 """
 
-U = "Price distribution by borough"
+USER = "Price distribution by borough"
 
 
 class SQLResult:
@@ -63,16 +63,18 @@ def __init__(self, solved: bool, explanation: str, sql: str):
         self.sql = sql
 
     @staticmethod
-    def from_response(llm_response: str) -> SQLResult:
+    def from_response(response: LLMResponse) -> SQLResult:
+        llm_response = response.result.first_choice if response.result else ""
         json_bloc = CodeParser.find_first("json", llm_response)
         if json_bloc is None:
             raise LLMParsingException("No json block found in response")
-        response = json.loads(json_bloc.code)
-        sql = response.get("sql")
+
+        code_response = json.loads(json_bloc.code)
+        sql = code_response.get("sql")
         if sql is not None:
             if "LIMIT" not in sql:
                 raise LLMParsingException("Generated SQL query should contain a LIMIT clause")
-            return SQLResult(response["solved"], response["explanation"], sql)
+            return SQLResult(code_response["solved"], code_response["explanation"], sql)
         return SQLResult(False, "No SQL query generated", "")
 
def __str__(self):
Expand All @@ -81,15 +83,21 @@ def __str__(self):
return f"Not solved.\nExplanation: {self.explanation}"


class TestLlmAzure(unittest.TestCase):
class TestLlmFunction(unittest.TestCase):
"""requires an Azure LLM model deployed"""

def setUp(self) -> None:
dotenv.load_dotenv()
self.llm = AzureLLM.from_env()

def test_basic_prompt(self):
llm_func = LLMFunction(self.llm, SQLResult.from_response, SP)
sql_result = llm_func.execute(U)
llm_func = LLMFunction(self.llm, SQLResult.from_response, SYSTEM_PROMPT)
sql_result = llm_func.execute(USER)
self.assertIsInstance(sql_result, SQLResult)
print("", sql_result, sep="\n")

def test_message_prompt(self):
llm_func = LLMFunction(self.llm, SQLResult.from_response, SYSTEM_PROMPT)
sql_result = llm_func.execute(LLMMessage.user_message(USER))
self.assertIsInstance(sql_result, SQLResult)
print("", sql_result, sep="\n")

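The new optional messages parameter is not exercised by these tests; a hypothetical companion test could pass extra context alongside the user message (the instruction text is illustrative, not part of this commit):

    def test_messages_prompt(self):
        llm_func = LLMFunction(self.llm, SQLResult.from_response, SYSTEM_PROMPT)
        # extra messages are appended after the user message before the call
        extra = [LLMMessage.user_message("Return at most 100 rows.")]
        sql_result = llm_func.execute(USER, messages=extra)
        self.assertIsInstance(sql_result, SQLResult)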