Commit 8759ef0

Merge branch 'main' into pr/143

aflament committed Jun 21, 2024
2 parents d6c2952 + ed1fb99 commit 8759ef0
Showing 37 changed files with 1,013 additions and 303 deletions.
8 changes: 4 additions & 4 deletions council/llm/__init__.py
@@ -8,18 +8,18 @@
 from .llm_answer import llm_property, LLMAnswer, LLMProperty, LLMParsingException
 from .llm_exception import LLMException, LLMCallException, LLMCallTimeoutException, LLMTokenLimitException
 from .llm_message import LLMMessageRole, LLMMessage, LLMessageTokenCounterBase
-from .llm_base import LLMBase, LLMResult
+from .llm_base import LLMBase, LLMResult, LLMConfigurationBase
 from .monitored_llm import MonitoredLLM
-from .llm_configuration_base import LLMConfigurationBase
+from .chat_gpt_configuration import ChatGPTConfigurationBase
 from .llm_fallback import LLMFallback

 from .openai_chat_completions_llm import OpenAIChatCompletionsModel
 from .openai_token_counter import OpenAITokenCounter

-from .azure_llm_configuration import AzureLLMConfiguration
+from .azure_chat_gpt_configuration import AzureChatGPTConfiguration
 from .azure_llm import AzureLLM

-from .openai_llm_configuration import OpenAILLMConfiguration
+from .openai_chat_gpt_configuration import OpenAIChatGPTConfiguration
 from .openai_llm import OpenAILLM

 from .anthropic_llm_configuration import AnthropicLLMConfiguration
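The practical effect of this reshuffle for downstream code: LLMConfigurationBase now lives in (and is re-exported from) llm_base, and the ChatGPT-oriented configuration base gets its own module and name. A minimal sketch of package-level imports that are valid after this commit, based only on the re-exports shown above:

from council.llm import ChatGPTConfigurationBase, LLMBase, LLMConfigurationBase, LLMResult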
15 changes: 7 additions & 8 deletions council/llm/anthropic_llm.py
@@ -32,16 +32,15 @@ def count_messages_token(self, messages: Sequence[LLMMessage]) -> int:
         return tokens


-class AnthropicLLM(LLMBase):
+class AnthropicLLM(LLMBase[AnthropicLLMConfiguration]):
     def __init__(self, config: AnthropicLLMConfiguration, name: Optional[str] = None) -> None:
         """
         Initialize a new instance.
         Args:
             config(AnthropicLLMConfiguration): configuration for the instance
         """
-        super().__init__(name=name or f"{self.__class__.__name__}")
-        self.config = config
+        super().__init__(name=name or f"{self.__class__.__name__}", configuration=config)
         self._client = Anthropic(api_key=config.api_key.value, max_retries=0)
         self._api = self._get_api_wrapper()

@@ -51,12 +50,12 @@ def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage]
             prompt_text = "\n".join([msg.content for msg in messages])
             return LLMResult(choices=response, consumptions=self.to_consumptions(prompt_text, response))
         except APITimeoutError as e:
-            raise LLMCallTimeoutException(self.config.timeout.value, self._name) from e
+            raise LLMCallTimeoutException(self._configuration.timeout.value, self._name) from e
         except APIStatusError as e:
             raise LLMCallException(code=e.status_code, error=e.message, llm_name=self._name) from e

     def to_consumptions(self, prompt: str, responses: List[str]) -> Sequence[Consumption]:
-        model = self.config.model.unwrap()
+        model = self._configuration.model_name()
         prompt_tokens = self._client.count_tokens(prompt)
         completion_tokens = sum(self._client.count_tokens(r) for r in responses)
         return [
@@ -67,9 +66,9 @@
         ]

     def _get_api_wrapper(self) -> AnthropicAPIClientWrapper:
-        if self.config.model.value == "claude-2":
-            return AnthropicCompletionLLM(client=self._client, config=self.config)
-        return AnthropicMessagesLLM(client=self._client, config=self.config)
+        if self._configuration is not None and self._configuration.model_name() == "claude-2":
+            return AnthropicCompletionLLM(client=self._client, config=self.configuration)
+        return AnthropicMessagesLLM(client=self._client, config=self.configuration)

     @staticmethod
     def from_env() -> AnthropicLLM:
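A short usage sketch of the refactored class. It assumes the ANTHROPIC_* environment variables read by from_env are set, and that the LLMContext.empty() and LLMMessage.user_message helpers behave as elsewhere in council; those helpers are assumptions here, not part of this diff.

from council.contexts import LLMContext
from council.llm import LLMMessage
from council.llm.anthropic_llm import AnthropicLLM

llm = AnthropicLLM.from_env()  # reads ANTHROPIC_* configuration from the environment
result = llm.post_chat_request(LLMContext.empty(), [LLMMessage.user_message("Hello")])
print(llm.model_name)     # resolved through the typed configuration
print(result.choices[0])  # first completion returned by the model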
17 changes: 11 additions & 6 deletions council/llm/anthropic_llm_configuration.py
@@ -1,12 +1,11 @@
 from __future__ import annotations

-from typing import Any, Optional
+from typing import Any, Final, Optional

-from council.llm import LLMConfigSpec
-from council.llm.llm_configuration_base import _DEFAULT_TIMEOUT
+from council.llm import LLMConfigSpec, LLMConfigurationBase
 from council.utils import Parameter, greater_than_validator, prefix_validator, read_env_int, read_env_str

-_env_var_prefix = "ANTHROPIC_"
+_env_var_prefix: Final[str] = "ANTHROPIC_"


 def _tv(x: float) -> None:
@@ -18,7 +17,7 @@ def _tv(x: float) -> None:
         raise ValueError("must be in the range [0.0..1.0]")


-class AnthropicLLMConfiguration:
+class AnthropicLLMConfiguration(LLMConfigurationBase):
     """
     Configuration for :class:AnthropicLLM
     """
@@ -42,12 +41,18 @@ def __init__(self, model: str, api_key: str, max_tokens: int) -> None:
         )

         self._timeout = Parameter.int(
-            name="timeout", required=False, default=_DEFAULT_TIMEOUT, validator=greater_than_validator(0)
+            name="timeout", required=False, default=self.default_timeout, validator=greater_than_validator(0)
         )
         self._temperature = Parameter.float(name="temperature", required=False, default=0.0, validator=_tv)
         self._top_p = Parameter.float(name="top_p", required=False, validator=_tv)
         self._top_k = Parameter.int(name="top_k", required=False, validator=greater_than_validator(0))

+    def model_name(self) -> str:
+        """
+        Anthropic model name
+        """
+        return self._model.unwrap()
+
     @property
     def model(self) -> Parameter[str]:
         """
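A minimal sketch of the new model_name() accessor; the constructor values below are placeholders.

from council.llm import AnthropicLLMConfiguration

config = AnthropicLLMConfiguration(model="claude-2", api_key="sk-ant-...", max_tokens=300)
print(config.model_name())  # "claude-2", unwrapped from the model parameter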
council/llm/{azure_llm_configuration.py → azure_chat_gpt_configuration.py}
@@ -1,26 +1,25 @@
 from __future__ import annotations

-from typing import Optional
+from typing import Final, Optional

-from council.llm import LLMConfigurationBase
+from council.llm import ChatGPTConfigurationBase
 from council.llm.llm_config_object import LLMConfigSpec
-from council.llm.llm_configuration_base import _DEFAULT_TIMEOUT
 from council.utils import Parameter, greater_than_validator, not_empty_validator, read_env_str

-_env_var_prefix = "AZURE_"
+_env_var_prefix: Final[str] = "AZURE_"


-class AzureLLMConfiguration(LLMConfigurationBase):
+class AzureChatGPTConfiguration(ChatGPTConfigurationBase):
     """
     Configuration for :class:AzureLLM
     Notes:
         https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#completions
     """

-    def __init__(self, api_key: str, api_base: str, deployment_name: str):
+    def __init__(self, api_key: str, api_base: str, deployment_name: str, model_name: Optional[str] = None) -> None:
         """
-        Initialize a new instance of OpenAILLMConfiguration
+        Initialize a new instance of AzureChatGPTConfiguration
         Args:
             api_key (str): the Azure api key
         """
@@ -32,8 +31,12 @@ def __init__(self, api_key: str, api_base: str, deployment_name: str):
         )
         self._api_version = Parameter.string(name="api_version", required=False, default="2023-05-15")
         self._timeout = Parameter.int(
-            name="timeout", required=False, default=_DEFAULT_TIMEOUT, validator=greater_than_validator(0)
+            name="timeout", required=False, default=self.default_timeout, validator=greater_than_validator(0)
         )
+        self._model_name = model_name or deployment_name
+
+    def model_name(self) -> str:
+        return self._model_name

     @property
     def api_base(self) -> Parameter[str]:
@@ -76,23 +79,25 @@ def _read_optional_env(self):
         self._timeout.from_env(_env_var_prefix + "LLM_TIMEOUT")

     @staticmethod
-    def from_env(deployment_name: Optional[str] = None) -> AzureLLMConfiguration:
+    def from_env(deployment_name: Optional[str] = None) -> AzureChatGPTConfiguration:
         api_key = read_env_str(_env_var_prefix + "LLM_API_KEY").unwrap()
         api_base = read_env_str(_env_var_prefix + "LLM_API_BASE").unwrap()
         if deployment_name is None:
             deployment_name = read_env_str(_env_var_prefix + "LLM_DEPLOYMENT_NAME", required=False).unwrap()

-        config = AzureLLMConfiguration(api_key=api_key, api_base=api_base, deployment_name=deployment_name)
+        config = AzureChatGPTConfiguration(api_key=api_key, api_base=api_base, deployment_name=deployment_name)
         config.read_env(env_var_prefix=_env_var_prefix)
         config._read_optional_env()
         return config

     @staticmethod
-    def from_spec(spec: LLMConfigSpec) -> AzureLLMConfiguration:
+    def from_spec(spec: LLMConfigSpec) -> AzureChatGPTConfiguration:
         api_key: str = spec.provider.must_get_value("apiKey")
         deployment_name: str = spec.provider.must_get_value("deploymentName")
         api_base: str = spec.provider.must_get_value("apiBase")
-        config = AzureLLMConfiguration(api_key=api_key, api_base=str(api_base), deployment_name=str(deployment_name))
+        config = AzureChatGPTConfiguration(
+            api_key=api_key, api_base=str(api_base), deployment_name=str(deployment_name)
+        )

         if spec.parameters is not None:
             config.from_dict(spec.parameters)
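A sketch of the new optional model_name argument, which falls back to the deployment name; the endpoint and deployment values below are hypothetical.

from council.llm import AzureChatGPTConfiguration

config = AzureChatGPTConfiguration(
    api_key="...", api_base="https://example.openai.azure.com", deployment_name="gpt-35-prod"
)
print(config.model_name())  # "gpt-35-prod": falls back to the deployment name

config = AzureChatGPTConfiguration(
    api_key="...",
    api_base="https://example.openai.azure.com",
    deployment_name="gpt-35-prod",
    model_name="gpt-3.5-turbo",
)
print(config.model_name())  # "gpt-3.5-turbo"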
10 changes: 5 additions & 5 deletions council/llm/azure_llm.py
@@ -6,7 +6,7 @@
 from httpx import HTTPStatusError, TimeoutException

 from . import LLMCallException, LLMCallTimeoutException, OpenAIChatCompletionsModel
-from .azure_llm_configuration import AzureLLMConfiguration
+from .azure_chat_gpt_configuration import AzureChatGPTConfiguration
 from .llm_config_object import LLMConfigObject, LLMProviders


@@ -15,7 +15,7 @@ class AzureOpenAIChatCompletionsModelProvider:
     Represents an OpenAI language model hosted on Azure.
     """

-    def __init__(self, config: AzureLLMConfiguration, name: Optional[str]) -> None:
+    def __init__(self, config: AzureChatGPTConfiguration, name: Optional[str]) -> None:
         self.config = config
         self._uri = (
             f"{self.config.api_base.value}/openai/deployments/{self.config.deployment_name.value}/chat/completions"
@@ -41,13 +41,13 @@ class AzureLLM(OpenAIChatCompletionsModel):
     Represents an OpenAI language model hosted on Azure.
     """

-    def __init__(self, config: AzureLLMConfiguration, name: Optional[str] = None) -> None:
+    def __init__(self, config: AzureChatGPTConfiguration, name: Optional[str] = None) -> None:
         name = name or f"{self.__class__.__name__}"
         super().__init__(config, AzureOpenAIChatCompletionsModelProvider(config, name).post_request, None, name)

     @staticmethod
     def from_env(deployment_name: Optional[str] = None) -> AzureLLM:
-        config: AzureLLMConfiguration = AzureLLMConfiguration.from_env(deployment_name)
+        config: AzureChatGPTConfiguration = AzureChatGPTConfiguration.from_env(deployment_name)
         return AzureLLM(config, deployment_name)

     @staticmethod
@@ -56,5 +56,5 @@ def from_config(config_object: LLMConfigObject) -> AzureLLM:
         if not provider.is_of_kind(LLMProviders.Azure):
             raise ValueError(f"Invalid LLM provider, actual {provider}, expected {LLMProviders.Azure}")

-        config = AzureLLMConfiguration.from_spec(config_object.spec)
+        config = AzureChatGPTConfiguration.from_spec(config_object.spec)
         return AzureLLM(config=config, name=config_object.metadata.name)
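Construction is unchanged apart from the type names; a hedged sketch of the env-based path, using the AZURE_LLM_* variables read by from_env above:

from council.llm import AzureLLM

# Requires AZURE_LLM_API_KEY, AZURE_LLM_API_BASE and, when no argument is
# passed, AZURE_LLM_DEPLOYMENT_NAME to be set in the environment.
llm = AzureLLM.from_env()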
council/llm/{llm_configuration_base.py → chat_gpt_configuration.py}
@@ -1,9 +1,8 @@
-import abc
+from abc import ABC
 from typing import Any, Dict, Optional

-from council.utils.parameter import Parameter
-
-_DEFAULT_TIMEOUT = 30
+from council.llm.llm_base import LLMConfigurationBase
+from council.utils import Parameter


 def _tv(x: float):
@@ -33,7 +32,7 @@ def _mtv(x: int):
         raise ValueError("must be positive")


-class LLMConfigurationBase(abc.ABC):
+class ChatGPTConfigurationBase(LLMConfigurationBase, ABC):
     """
     Configuration for OpenAI LLM Chat Completion GPT Model
     """
@@ -123,19 +122,22 @@ def add_param(parameter: Parameter):
         return payload

     def from_dict(self, values: Dict[str, Any]):
-        value: Optional[Any] = None
-        value = values.get("temperature", None)
-        if value is not None:
+        value: Optional[Any] = values.get("temperature", None)
+        if value:
             self.temperature.set(float(value))

         value = values.get("n", None)
-        if value is not None:
+        if value:
             self.n.set(int(value))

         value = values.get("maxTokens", None)
-        if value is not None:
+        if value:
             self.max_tokens.set(int(value))

         value = values.get("topP", None)
-        if value is not None:
+        if value:
             self.top_p.set(float(value))

         value = values.get("presencePenalty", None)
         if value is not None:
             self.presence_penalty.set(float(value))
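Note the behavior change in from_dict: "if value:" skips falsy values, so an explicit 0 or 0.0 in a spec (for example temperature 0.0) no longer calls set() and the parameter keeps its default, whereas presencePenalty still uses the "is not None" test. A sketch through the Azure subclass shown earlier; the key and endpoint values are placeholders.

from council.llm import AzureChatGPTConfiguration

config = AzureChatGPTConfiguration(api_key="...", api_base="https://example.openai.azure.com", deployment_name="gpt-4")
config.from_dict({"temperature": 0.7, "maxTokens": 1024})
print(config.temperature.value)  # 0.7

config.from_dict({"temperature": 0.0})  # falsy, so the parameter is left untouched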
4 changes: 3 additions & 1 deletion council/llm/llm_answer.py
@@ -8,7 +8,9 @@


 class LLMParsingException(Exception):
-    pass
+    def __init__(self, message: str = "Your response is not correctly formatted.") -> None:
+        super().__init__(message)
+        self.message = message


 class llm_property(property):
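A tiny sketch of the richer exception: the default message is phrased so it can be sent back to the model when retrying structured-output parsing.

from council.llm import LLMParsingException

try:
    raise LLMParsingException()  # uses the default message
except LLMParsingException as e:
    print(e.message)  # "Your response is not correctly formatted."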
36 changes: 33 additions & 3 deletions council/llm/llm_base.py
@@ -1,10 +1,26 @@
 import abc
-from typing import Any, Optional, Sequence
+from typing import Any, Final, Generic, Optional, Sequence, TypeVar

 from council.contexts import Consumption, LLMContext, Monitorable

 from .llm_message import LLMessageTokenCounterBase, LLMMessage

+_DEFAULT_TIMEOUT: Final[int] = 30
+
+
+class LLMConfigurationBase(abc.ABC):
+
+    @abc.abstractmethod
+    def model_name(self) -> str:
+        pass
+
+    @property
+    def default_timeout(self) -> int:
+        return _DEFAULT_TIMEOUT
+
+
+T_Configuration = TypeVar("T_Configuration", bound=LLMConfigurationBase)
+

 class LLMResult:
     def __init__(self, choices: Sequence[str], consumptions: Optional[Sequence[Consumption]] = None):
@@ -24,15 +40,29 @@ def consumptions(self) -> Sequence[Consumption]:
         return self._consumptions


-class LLMBase(Monitorable, abc.ABC):
+class LLMBase(Generic[T_Configuration], Monitorable, abc.ABC):
     """
     Abstract base class representing a language model.
     """

-    def __init__(self, token_counter: Optional[LLMessageTokenCounterBase] = None, name: Optional[str] = None):
+    def __init__(
+        self,
+        configuration: T_Configuration,
+        token_counter: Optional[LLMessageTokenCounterBase] = None,
+        name: Optional[str] = None,
+    ) -> None:
         super().__init__(name or "llm")
         self._token_counter = token_counter
         self._name = name or f"llm_{self.__class__.__name__}"
+        self._configuration = configuration
+
+    @property
+    def configuration(self) -> T_Configuration:
+        return self._configuration
+
+    @property
+    def model_name(self) -> str:
+        return self.configuration.model_name()

     def post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
         """
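A sketch of what the generics buy a provider implementation: the subclass declares its configuration type once and inherits a typed configuration property plus model_name. The _post_chat_request override mirrors AnthropicLLM above; whether that is the exact abstract hook in LLMBase is an assumption, since that part of the file is collapsed in this diff.

from typing import Any, Sequence

from council.contexts import LLMContext
from council.llm import LLMBase, LLMConfigurationBase, LLMMessage, LLMResult


class EchoConfiguration(LLMConfigurationBase):
    def model_name(self) -> str:
        return "echo-1"


class EchoLLM(LLMBase[EchoConfiguration]):
    def __init__(self) -> None:
        super().__init__(configuration=EchoConfiguration(), name="EchoLLM")

    def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
        # Echo the last message back as the single choice.
        return LLMResult(choices=[messages[-1].content])


llm = EchoLLM()
print(llm.model_name)  # "echo-1", via LLMBase.model_name -> configuration.model_name()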
21 changes: 17 additions & 4 deletions council/llm/llm_exception.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Sequence


 class LLMException(Exception):
@@ -17,9 +17,8 @@ def __init__(self, message: str, llm_name: Optional[str]) -> None:
         Returns:
             None
         """
-        super().__init__(
-            f"llm:{llm_name}, message {message}" if llm_name is not None and len(llm_name) > 0 else message
-        )
+        self.message = f"llm:{llm_name}, message {message}" if llm_name is not None and len(llm_name) > 0 else message
+        super().__init__(self.message)


 class LLMCallTimeoutException(LLMException):
@@ -89,3 +88,17 @@ def __init__(self, token_count: int, limit: int, model: str, llm_name: Optional[
             None
         """
         super().__init__(f"token_count={token_count} is exceeding model {model} limit of {limit} tokens.", llm_name)
+
+
+class LLMOutOfRetriesException(LLMException):
+    """
+    Custom exception raised when the maximum number of retries is reached.
+    """
+
+    def __init__(
+        self, llm_name: Optional[str], retry_count: int, exceptions: Optional[Sequence[Exception]] = None
+    ) -> None:
+        """
+        Initializes an instance of LLMOutOfRetriesException.
+        """
+        super().__init__(f"Exceeded maximum retries after {retry_count} attempts", llm_name)