diff --git a/council/__init__.py b/council/__init__.py
index a1869ca3..ee6a6fe6 100644
--- a/council/__init__.py
+++ b/council/__init__.py
@@ -1,4 +1,5 @@
 """Init file."""
+
 from .agents import Agent, AgentChain, AgentResult
 from .chains import Chain, ChainBase
 from .contexts import AgentContext, Budget, ChainContext, ChatHistory, ChatMessage, LLMContext, SkillContext
diff --git a/council/agent_tests/agent_tests.py b/council/agent_tests/agent_tests.py
index ad2dbc54..ed93cbf0 100644
--- a/council/agent_tests/agent_tests.py
+++ b/council/agent_tests/agent_tests.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import time
 from enum import Enum
 from typing import List, Dict, Any, Sequence, Optional
@@ -186,7 +188,7 @@ def __init__(self, test_cases: Optional[List[AgentTestCase]] = None):
         else:
             self._test_cases = []

-    def add_test_case(self, prompt: str, scorers: List[ScorerBase]) -> "AgentTestSuite":
+    def add_test_case(self, prompt: str, scorers: List[ScorerBase]) -> AgentTestSuite:
         self._test_cases.append(AgentTestCase(prompt, scorers))
         return self
diff --git a/council/agents/agent.py b/council/agents/agent.py
index 735eb202..6415fd36 100644
--- a/council/agents/agent.py
+++ b/council/agents/agent.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import itertools
 from concurrent import futures
 from typing import Dict, List, Optional, Sequence
@@ -165,7 +167,7 @@ def _execute_unit(iteration_context: AgentContext, unit: ExecutionUnit):
         context.logger.info(f'message="chain execution ended" chain="{chain.name}" execution_unit="{unit.name}"')

     @staticmethod
-    def from_skill(skill: SkillBase, chain_description: Optional[str] = None) -> "Agent":
+    def from_skill(skill: SkillBase, chain_description: Optional[str] = None) -> Agent:
         """
         Helper function to create a new agent with a :class:`.BasicController`, a :class:`.BasicEvaluator` and a
         single :class:`.SkillBase` wrapped into a :class:`.Chain`
@@ -182,7 +184,7 @@ def from_skill(skill: SkillBase, chain_description: Optional[str] = None) -> "Ag
     @staticmethod
     def from_chain(
         chain: ChainBase, evaluator: EvaluatorBase = BasicEvaluator(), filter: FilterBase = BasicFilter()
-    ) -> "Agent":
+    ) -> Agent:
         """
         Helper function to create a new agent with a :class:`.BasicController`, a :class:`.BasicEvaluator` and a
         single :class:`.SkillBase` wrapped into a :class:`.Chain`
diff --git a/council/contexts/_agent_context.py b/council/contexts/_agent_context.py
index e068cfa8..1aee4a5f 100644
--- a/council/contexts/_agent_context.py
+++ b/council/contexts/_agent_context.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable, Optional, Sequence

 from ._agent_context_store import AgentContextStore
@@ -15,11 +17,11 @@ class AgentContext(ContextBase):
     the execution context given to an :class:`~council.agents.Agent`
     """

-    def __init__(self, store: AgentContextStore, execution_context: ExecutionContext, budget: Budget):
+    def __init__(self, store: AgentContextStore, execution_context: ExecutionContext, budget: Budget) -> None:
         super().__init__(store, execution_context, budget)

     @staticmethod
-    def empty(budget: Optional[Budget] = None) -> "AgentContext":
+    def empty(budget: Optional[Budget] = None) -> AgentContext:
         """
         creates a new instance with no data

@@ -29,7 +31,7 @@ def empty(budget: Optional[Budget] = None) -> "AgentContext":
         return AgentContext.from_chat_history(ChatHistory(), budget)

     @staticmethod
-    def from_chat_history(chat_history: ChatHistory, budget: Optional[Budget] = None) -> "AgentContext":
+    def from_chat_history(chat_history: ChatHistory, budget: Optional[Budget] = None) -> AgentContext:
         """
         creates a new instance from a :class:`ChatHistory`

@@ -41,7 +43,7 @@ def from_chat_history(chat_history: ChatHistory, budget: Optional[Budget] = None
         return AgentContext(store, ExecutionContext(store.execution_log, "agent"), budget or Budget.default())

     @staticmethod
-    def from_user_message(message: str, budget: Optional[Budget] = None) -> "AgentContext":
+    def from_user_message(message: str, budget: Optional[Budget] = None) -> AgentContext:
         """
         creates a new instance from a user message. The :class:`ChatHistory` contains only the given message

@@ -52,7 +54,7 @@ def from_user_message(message: str, budget: Optional[Budget] = None) -> "AgentCo
         """
         return AgentContext.from_chat_history(ChatHistory.from_user_message(message), budget)

-    def new_agent_context_for(self, monitored: Monitored) -> "AgentContext":
+    def new_agent_context_for(self, monitored: Monitored) -> AgentContext:
         """
         creates a new instance for the given object, adjusting the execution context appropriately

@@ -67,7 +69,7 @@ def new_iteration(self) -> None:
         """
         self._store.new_iteration()

-    def new_agent_context_for_new_iteration(self) -> "AgentContext":
+    def new_agent_context_for_new_iteration(self) -> AgentContext:
         """
         creates a new instance, adjusting the execution context appropriately
         """
@@ -75,7 +77,7 @@ def new_agent_context_for_new_iteration(self) -> "AgentContext":
         name = f"iterations[{len(self._store.iterations) - 1}]"
         return AgentContext(self._store, self._execution_context.new_from_name(name), self._budget)

-    def new_agent_context_for_execution_unit(self, name: str) -> "AgentContext":
+    def new_agent_context_for_execution_unit(self, name: str) -> AgentContext:
         """
         creates a new instance, adjusting the execution context for the given name
diff --git a/council/contexts/_budget.py b/council/contexts/_budget.py
index 737c989b..45592c32 100644
--- a/council/contexts/_budget.py
+++ b/council/contexts/_budget.py
@@ -25,7 +25,7 @@ class Consumption:

     """

-    def __init__(self, value: float, unit: str, kind: str):
+    def __init__(self, value: float, unit: str, kind: str) -> None:
         """
         Initializes a Consumption instance.

@@ -54,16 +54,28 @@ def kind(self) -> str:
     def __str__(self):
         return f"{self._kind} consumption: {self._value} {self.unit}"

-    def add(self, value: float) -> "Consumption":
+    def add(self, value: float) -> Consumption:
+        """
+        Returns a new Consumption instance with the value incremented by the specified value.
+        """
         return Consumption(self._value + value, self.unit, self._kind)

-    def subtract(self, value: float) -> "Consumption":
+    def subtract(self, value: float) -> Consumption:
+        """
+        Returns a new Consumption instance with the value decremented by the specified value.
+        """
         return Consumption(self._value - value, self.unit, self._kind)

     def add_value(self, value: float) -> None:
+        """
+        Increments the value of the consumption by the specified value.
+        """
         self._value += value

     def subtract_value(self, value: float) -> None:
+        """
+        Decrements the value of the consumption by the specified value.
+        """
         self._value -= value

     def to_dict(self) -> Dict[str, Any]:
@@ -160,7 +172,7 @@ def __repr__(self):
         return f"Budget({self._duration})"

     @staticmethod
-    def default() -> "Budget":
+    def default() -> Budget:
         """
         Helper function that creates a new Budget with a default value.
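The Consumption docstrings added above encode an API distinction worth spelling out: `add`/`subtract` are pure and return a fresh instance, while `add_value`/`subtract_value` mutate the receiver in place. A minimal sketch of the difference, importing Consumption straight from its defining module (the values and the "LLM"/"call" labels are illustrative):

    from council.contexts._budget import Consumption

    calls = Consumption(1.0, "call", "LLM")
    more = calls.add(2.0)    # new instance holding 3.0; `calls` still holds 1.0
    calls.add_value(2.0)     # in-place update; `calls` now holds 3.0
    print(more)              # LLM consumption: 3.0 call
    print(calls)             # LLM consumption: 3.0 call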
diff --git a/council/contexts/_chain_context.py b/council/contexts/_chain_context.py
index f5f11714..fe10ca0d 100644
--- a/council/contexts/_chain_context.py
+++ b/council/contexts/_chain_context.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Iterable, List, Optional

 import more_itertools
@@ -106,7 +108,7 @@ def from_agent_context(context: AgentContext, monitored: Monitored, name: str, b
             context._store, context._execution_context.new_for(monitored), name, budget or Budget.default()
         )

-    def fork_for(self, monitored: Monitored, budget: Optional[Budget] = None) -> "ChainContext":
+    def fork_for(self, monitored: Monitored, budget: Optional[Budget] = None) -> ChainContext:
         """
         forks the context for the given object, adjusting the execution context appropriately
         """
@@ -156,7 +158,7 @@ def extend(self, messages: Iterable[ChatMessage]) -> None:
             self.append(message)

     @staticmethod
-    def from_chat_history(history: ChatHistory, budget: Optional[Budget] = None) -> "ChainContext":
+    def from_chat_history(history: ChatHistory, budget: Optional[Budget] = None) -> ChainContext:
         """
         helper function that creates a new instance from a :class:`ChatHistory`.

@@ -169,14 +171,14 @@ def from_chat_history(history: ChatHistory, budget: Optional[Budget] = None) ->
         return ChainContext.from_agent_context(context, MockMonitored("mock chain"), "mock chain", budget)

     @staticmethod
-    def from_user_message(message: str, budget: Optional[Budget] = None) -> "ChainContext":
+    def from_user_message(message: str, budget: Optional[Budget] = None) -> ChainContext:
         """
         creates a new instance from a user message. The :class:`ChatHistory` contains only the user message
         """
         return ChainContext.from_chat_history(ChatHistory.from_user_message(message), budget)

     @staticmethod
-    def empty() -> "ChainContext":
+    def empty() -> ChainContext:
         """
         helper function that creates a new empty instance.
diff --git a/council/contexts/_chat_history.py b/council/contexts/_chat_history.py
index 4588122a..bdcb883d 100644
--- a/council/contexts/_chat_history.py
+++ b/council/contexts/_chat_history.py
@@ -1,3 +1,4 @@
+from __future__ import annotations

 from ._message_list import MessageList

@@ -7,7 +8,7 @@ class ChatHistory(MessageList):
     """

     @staticmethod
-    def from_user_message(message: str) -> "ChatHistory":
+    def from_user_message(message: str) -> ChatHistory:
         """
         helper function that returns a new instance containing one user message
         """
diff --git a/council/contexts/_chat_message.py b/council/contexts/_chat_message.py
index d7e282f2..7f8547e9 100644
--- a/council/contexts/_chat_message.py
+++ b/council/contexts/_chat_message.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from enum import Enum
 from typing import Any, Dict

@@ -60,7 +62,7 @@ def __init__(self, message: str, kind: ChatMessageKind, data: Any = None, source
         self._is_error = is_error

     @staticmethod
-    def agent(message: str, data: Any = None, source: str = "", is_error: bool = False) -> "ChatMessage":
+    def agent(message: str, data: Any = None, source: str = "", is_error: bool = False) -> ChatMessage:
         """
         Helper function to create message of kind :attr:`ChatMessageKind.Agent`.
         See :meth:`ChatMessage.__init__` for details
@@ -68,7 +70,7 @@ def agent(message: str, data: Any = None, source: str = "", is_error: bool = Fal
         return ChatMessage(message, ChatMessageKind.Agent, data, source, is_error)

     @staticmethod
-    def user(message: str, data: Any = None, source: str = "", is_error: bool = False) -> "ChatMessage":
+    def user(message: str, data: Any = None, source: str = "", is_error: bool = False) -> ChatMessage:
         """
         Helper function to create message of kind :attr:`ChatMessageKind.User`.
         See :meth:`ChatMessage.__init__` for details
@@ -76,7 +78,7 @@ def user(message: str, data: Any = None, source: str = "", is_error: bool = Fals
         return ChatMessage(message, ChatMessageKind.User, data, source, is_error)

     @staticmethod
-    def skill(message: str, data: Any = None, source: str = "", is_error: bool = False) -> "ChatMessage":
+    def skill(message: str, data: Any = None, source: str = "", is_error: bool = False) -> ChatMessage:
         """
         Helper function to create message of kind :attr:`ChatMessageKind.Skill`.
         See :meth:`ChatMessage.__init__` for details
@@ -84,7 +86,7 @@ def skill(message: str, data: Any = None, source: str = "", is_error: bool = Fal
         return ChatMessage(message, ChatMessageKind.Skill, data, source, is_error)

     @staticmethod
-    def chain(message: str, data: Any = None, source: str = "", is_error: bool = False) -> "ChatMessage":
+    def chain(message: str, data: Any = None, source: str = "", is_error: bool = False) -> ChatMessage:
         """
         Helper function to create message of kind :attr:`ChatMessageKind.Chain`.
         See :meth:`ChatMessage.__init__` for details
diff --git a/council/contexts/_composite_message_collection.py b/council/contexts/_composite_message_collection.py
index 9333edb5..282571d7 100644
--- a/council/contexts/_composite_message_collection.py
+++ b/council/contexts/_composite_message_collection.py
@@ -17,11 +17,9 @@ def __init__(self, collections: List[MessageCollection]):
     @property
     def messages(self) -> Iterable[ChatMessage]:
         for collection in self._collections:
-            for message in collection.messages:
-                yield message
+            yield from collection.messages

     @property
     def reversed(self) -> Iterable[ChatMessage]:
         for collection in reversed(self._collections):
-            for message in collection.reversed:
-                yield message
+            yield from collection.reversed
diff --git a/council/contexts/_execution_context.py b/council/contexts/_execution_context.py
index 8515d4f3..f8c7328f 100644
--- a/council/contexts/_execution_context.py
+++ b/council/contexts/_execution_context.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from typing import Optional

 from ._monitorable import Monitorable
@@ -20,16 +21,16 @@ def __init__(
         self._executionLog = execution_log or ExecutionLog()
         self._entry = self._executionLog.new_entry(path, node)

-    def _new_path(self, name: str):
+    def _new_path(self, name: str) -> str:
         return name if self._entry.source == "" else f"{self._entry.source}/{name}"

-    def new_from_name(self, name: str):
+    def new_from_name(self, name: str) -> ExecutionContext:
         """
         returns a new instance for the given name
         """
         return ExecutionContext(self._executionLog, self._new_path(name))

-    def new_for(self, monitored: Monitored) -> "ExecutionContext":
+    def new_for(self, monitored: Monitored) -> ExecutionContext:
         """
         returns a new instance for the given object
         """
diff --git a/council/contexts/_llm_context.py b/council/contexts/_llm_context.py
index 2b101e7b..01485ebb 100644
--- a/council/contexts/_llm_context.py
+++ b/council/contexts/_llm_context.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from typing import Optional

 from ._agent_context_store import AgentContextStore
@@ -13,18 +14,18 @@ class LLMContext(ContextBase):
     represents a context used by a :class:`~council.llm.LLMBase`
     """

-    def __init__(self, store: AgentContextStore, execution_context: ExecutionContext, budget: Budget):
+    def __init__(self, store: AgentContextStore, execution_context: ExecutionContext, budget: Budget) -> None:
         super().__init__(store, execution_context, budget)

     @staticmethod
-    def from_context(context: ContextBase, monitored: Monitored, budget: Optional[Budget] = None) -> "LLMContext":
+    def from_context(context: ContextBase, monitored: Monitored, budget: Optional[Budget] = None) -> LLMContext:
         """
         creates a new instance from the given context, adjusting the execution context appropriately
         """
         return LLMContext(context._store, context._execution_context.new_for(monitored), budget or context._budget)

     @staticmethod
-    def empty() -> "LLMContext":
+    def empty() -> LLMContext:
         """
         helper function that creates a new empty instance

@@ -32,7 +33,7 @@ def empty() -> "LLMContext":
         """
         return LLMContext(AgentContextStore(ChatHistory()), ExecutionContext(), InfiniteBudget())

-    def new_for(self, monitored: Monitored) -> "LLMContext":
+    def new_for(self, monitored: Monitored) -> LLMContext:
         """
         returns a new instance for the given object, adjusting the execution context appropriately
         """
diff --git a/council/contexts/_scorer_context.py b/council/contexts/_scorer_context.py
index 8abdaba4..3afb3098 100644
--- a/council/contexts/_scorer_context.py
+++ b/council/contexts/_scorer_context.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Optional

 from ._agent_context_store import AgentContextStore
@@ -17,14 +19,14 @@ def __init__(self, store: AgentContextStore, execution_context: ExecutionContext
         super().__init__(store, execution_context, budget)

     @staticmethod
-    def from_context(context: ContextBase, monitored: Monitored, budget: Optional[Budget] = None) -> "ScorerContext":
+    def from_context(context: ContextBase, monitored: Monitored, budget: Optional[Budget] = None) -> ScorerContext:
         """
         creates a new instance from the given context, adjusting the execution context appropriately
         """
         return ScorerContext(context._store, context._execution_context.new_for(monitored), budget or context._budget)

     @staticmethod
-    def empty() -> "ScorerContext":
+    def empty() -> ScorerContext:
         """
         helper function that creates a new empty instance

@@ -32,7 +34,7 @@ def empty() -> "ScorerContext":
         """
         return ScorerContext(AgentContextStore(ChatHistory()), ExecutionContext(), InfiniteBudget())

-    def new_for(self, monitored: Monitored) -> "ScorerContext":
+    def new_for(self, monitored: Monitored) -> ScorerContext:
         """
         returns a new instance for the given object, adjusting the execution context appropriately
         """
diff --git a/council/contexts/_skill_context.py b/council/contexts/_skill_context.py
index 39f62bc4..7551dce3 100644
--- a/council/contexts/_skill_context.py
+++ b/council/contexts/_skill_context.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any, Iterable

 from council.utils import Option
@@ -13,7 +15,7 @@ class IterationContext:
     Provides context information when running inside a loop.
""" - def __init__(self, index: int, value: Any): + def __init__(self, index: int, value: Any) -> None: self._index = index self._value = value @@ -59,7 +61,7 @@ def __init__( budget: Budget, messages: Iterable[ChatMessage], iteration: Option[IterationContext], - ): + ) -> None: super().__init__(store, execution_context, name, budget, messages) self._iteration = iteration @@ -74,7 +76,7 @@ def iteration(self) -> Option[IterationContext]: return self._iteration @staticmethod - def from_chain_context(context: ChainContext, iteration: Option[IterationContext]) -> "SkillContext": + def from_chain_context(context: ChainContext, iteration: Option[IterationContext]) -> SkillContext: return SkillContext( context._store, context._execution_context, diff --git a/council/controllers/llm_controller.py b/council/controllers/llm_controller.py index 61303806..c0137852 100644 --- a/council/controllers/llm_controller.py +++ b/council/controllers/llm_controller.py @@ -7,7 +7,7 @@ from council.utils import Option from council.controllers import ControllerBase, ControllerException from .execution_unit import ExecutionUnit -from council.llm.llm_answer import llm_property, LLMAnswer, LLMParsingException +from council.llm.llm_answer import llm_property, LLMAnswer, LLMParsingException, llm_class_validator class Specialist: @@ -42,6 +42,11 @@ def __str__(self): f"The specialist `{self._name}` was scored `{self._score}` with the justification `{self._justification}`" ) + @llm_class_validator + def validate(self): + if self._score < 0 or self._score > 10: + raise LLMParsingException(f"Specialist's score `{self._score}` is invalid, value must be between 0 and 10.") + class LLMController(ControllerBase): """ diff --git a/council/evaluators/llm_evaluator.py b/council/evaluators/llm_evaluator.py index 2c90d0f4..c9986070 100644 --- a/council/evaluators/llm_evaluator.py +++ b/council/evaluators/llm_evaluator.py @@ -3,12 +3,13 @@ This evaluator uses the given `LLM` to evaluate the chain's responses. """ + from typing import List, Optional from council.contexts import AgentContext, ChatMessage, ScoredChatMessage, ContextBase from council.evaluators import EvaluatorBase, EvaluatorException from council.llm import LLMBase, MonitoredLLM, llm_property, LLMAnswer, LLMMessage -from council.llm.llm_answer import LLMParsingException +from council.llm.llm_answer import LLMParsingException, llm_class_validator from council.utils import Option @@ -36,6 +37,11 @@ def justification(self) -> str: def __str__(self): return f"Message `{self._index}` graded `{self._grade}` with the justification: `{self._justification}`" + @llm_class_validator + def validate(self): + if self._grade < 0.0 or self._grade > 10.0: + raise LLMParsingException(f"Grade `{self._grade}` is invalid, value must be between 0.0 and 10.0") + class LLMEvaluator(EvaluatorBase): """Evaluator using an `LLM` to evaluate chain responses.""" diff --git a/council/filters/llm_filter.py b/council/filters/llm_filter.py index 00dcb070..f9e18685 100644 --- a/council/filters/llm_filter.py +++ b/council/filters/llm_filter.py @@ -3,6 +3,7 @@ This filter uses the given `LLM` to filter the chain's responses. 
""" + from typing import List, Optional from council.contexts import AgentContext, ScoredChatMessage, ContextBase diff --git a/council/llm/__init__.py b/council/llm/__init__.py index 9aa082b0..92ae98d6 100644 --- a/council/llm/__init__.py +++ b/council/llm/__init__.py @@ -1,14 +1,14 @@ """This package provides clients to use various LLMs""" + from typing import Optional from ..utils import read_env_str from .llm_config_object import LLMProvider, LLMConfigObject, LLMConfigSpec, LLMProviders -from .llm_answer import llm_property, LLMAnswer, LLMProperty +from .llm_answer import llm_property, LLMAnswer, LLMProperty, LLMParsingException from .llm_exception import LLMException, LLMCallException, LLMCallTimeoutException, LLMTokenLimitException from .llm_message import LLMMessageRole, LLMMessage, LLMessageTokenCounterBase from .llm_base import LLMBase, LLMResult -from .llm_answer import LLMAnswer, LLMParsingException from .monitored_llm import MonitoredLLM from .llm_configuration_base import LLMConfigurationBase from .llm_fallback import LLMFallback diff --git a/council/llm/llm_answer.py b/council/llm/llm_answer.py index dfee9144..160d28bd 100644 --- a/council/llm/llm_answer.py +++ b/council/llm/llm_answer.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Callable import yaml @@ -18,6 +18,11 @@ def __init__(self, fget=None, fset=None, fdel=None, doc=None): self.rank = inspect.getsourcelines(fget)[1] +class llm_class_validator: + def __init__(self, func: Callable): + self.f = func + + class LLMProperty: def __init__(self, name: str, prop: llm_property): self._name = name @@ -66,12 +71,15 @@ class LLMAnswer: def __init__(self, schema: Any): self._schema = schema self._class_name = schema.__name__ + self._valid_func = None properties = [] getmembers = inspect.getmembers(schema) for attr_name, attr_value in getmembers: if isinstance(attr_value, llm_property): prop_info = LLMProperty(name=attr_name, prop=attr_value) properties.append(prop_info) + if isinstance(attr_value, llm_class_validator): + self._valid_func = attr_value.f properties.sort(key=lambda item: item.rank) self._properties = properties @@ -87,7 +95,7 @@ def to_yaml_prompt(self) -> str: fp = [ "Use precisely the following template:", "```yaml", - "{your yaml formatted answer}", + f"your yaml formatted answer for the `{self._class_name}` class.", "```", "\n", ] @@ -98,8 +106,10 @@ def to_object(self, line: str) -> Optional[Any]: d = self.parse_line(line, None) missing_keys = [key.name for key in self._properties if key.name not in d.keys()] if len(missing_keys) > 0: - raise LLMParsingException(f"Missing {missing_keys} in response.") + raise LLMParsingException(f"Missing `{missing_keys}` in response.") t = self._schema(**d) + if self._valid_func is not None: + self._valid_func(t) return t def parse_line(self, line: str, default: Optional[Any] = "Invalid") -> Dict[str, Any]: @@ -125,7 +135,7 @@ def parse_yaml(self, bloc: str) -> Dict[str, Any]: properties_dict = {**d} missing_keys = [key.name for key in self._properties if key.name not in properties_dict.keys()] if len(missing_keys) > 0: - raise LLMParsingException(f"Missing {missing_keys} in response.") + raise LLMParsingException(f"Missing `{missing_keys}` in response.") return properties_dict def parse_yaml_list(self, bloc: str) -> List[Dict[str, Any]]: @@ -135,7 +145,7 @@ def parse_yaml_list(self, bloc: str) -> List[Dict[str, Any]]: properties_dict = {**item} missing_keys = 
             missing_keys = [key.name for key in self._properties if key.name not in properties_dict.keys()]
             if len(missing_keys) > 0:
-                raise LLMParsingException(f"Missing {missing_keys} in response.")
+                raise LLMParsingException(f"Missing `{missing_keys}` in response.")
             result.append(properties_dict)
         return result
diff --git a/council/llm/llm_config_object.py b/council/llm/llm_config_object.py
index 45465a0b..455d5e3f 100644
--- a/council/llm/llm_config_object.py
+++ b/council/llm/llm_config_object.py
@@ -7,6 +7,7 @@
 import yaml

 from council.utils import DataObject, DataObjectSpecBase
+from council.utils.parameter import Undefined


 class LLMProviders(str, Enum):
@@ -58,17 +59,20 @@ def to_dict(self) -> Dict[str, Any]:
     def must_get_value(self, key: str) -> Any:
         return self.get_value(key=key, required=True)

-    def get_value(self, key: str, required: bool = False) -> Optional[Any]:
+    def get_value(self, key: str, required: bool = False, default: Optional[Any] = Undefined()) -> Optional[Any]:
         maybe_value = self._specs.get(key, None)
-        if maybe_value is None and required:
-            raise Exception(f"{key} is required")
+        if maybe_value is None:
+            if not isinstance(default, Undefined):
+                return default

         if isinstance(maybe_value, dict):
-            name: Optional[str] = maybe_value.get("fromEnvVar", None)
-            if name is not None:
-                value = os.environ.get(name)
-                return value
+            default_value: Optional[str] = maybe_value.get("default", None)
+            env_var_name: Optional[str] = maybe_value.get("fromEnvVar", None)
+            if env_var_name is not None:
+                maybe_value = os.environ.get(env_var_name, default_value)
+
+        if maybe_value is None and required:
+            raise Exception(f"LLMProvider {self.name} - A required key {key} is missing.")
         return maybe_value

     def __str__(self):
diff --git a/council/llm/llm_message.py b/council/llm/llm_message.py
index e8e22625..f0b69a1c 100644
--- a/council/llm/llm_message.py
+++ b/council/llm/llm_message.py
@@ -40,16 +40,16 @@ class LLMMessage:
     _role: LLMMessageRole
     _content: str

-    def __init__(self, role: LLMMessageRole, content: str, name: Optional[str] = None):
+    def __init__(self, role: LLMMessageRole, content: str, name: Optional[str] = None) -> None:
         """Initialize a new instance"""
         self._role = role
         self._content = content
         self._name = name

     @staticmethod
-    def system_message(content: str, name: Optional[str] = None) -> "LLMMessage":
+    def system_message(content: str, name: Optional[str] = None) -> LLMMessage:
         """
-        Create a new system message
+        Create a new system message instance

         Parameters:
             content (str): the message content
@@ -58,9 +58,9 @@ def system_message(content: str, name: Optional[str] = None) -> "LLMMessage":
         return LLMMessage(role=LLMMessageRole.System, content=content, name=name)

     @staticmethod
-    def user_message(content: str, name: Optional[str] = None) -> "LLMMessage":
+    def user_message(content: str, name: Optional[str] = None) -> LLMMessage:
         """
-        Create a new user message
+        Create a new user message instance

         Parameters:
             content (str): the message content
@@ -69,9 +69,9 @@ def user_message(content: str, name: Optional[str] = None) -> "LLMMessage":
         return LLMMessage(role=LLMMessageRole.User, content=content, name=name)

     @staticmethod
-    def assistant_message(content: str, name: Optional[str] = None) -> "LLMMessage":
+    def assistant_message(content: str, name: Optional[str] = None) -> LLMMessage:
         """
-        Create a new assistant message
+        Create a new assistant message instance

         Parameters:
             content (str): the message content
@@ -105,7 +105,7 @@ def is_of_role(self, role: LLMMessageRole) -> bool:
         return self._role == role

     @staticmethod
-    def from_chat_message(chat_message: ChatMessage) -> Optional["LLMMessage"]:
+    def from_chat_message(chat_message: ChatMessage) -> Optional[LLMMessage]:
         """Convert :class:`~.ChatMessage` into :class:`.LLMMessage`"""
         if chat_message.kind == ChatMessageKind.User:
             return LLMMessage.user_message(chat_message.message)
@@ -114,7 +114,7 @@ def from_chat_message(chat_message: ChatMessage) -> Optional["LLMMessage"]:
         return None

     @staticmethod
-    def from_chat_messages(messages: Iterable[ChatMessage]) -> List["LLMMessage"]:
+    def from_chat_messages(messages: Iterable[ChatMessage]) -> List[LLMMessage]:
         m = map(LLMMessage.from_chat_message, messages)
         return [msg for msg in m if msg is not None]

diff --git a/council/llm/openai_chat_completions_llm.py b/council/llm/openai_chat_completions_llm.py
index 58d5bdd2..cdc24347 100644
--- a/council/llm/openai_chat_completions_llm.py
+++ b/council/llm/openai_chat_completions_llm.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx

 from typing import List, Any, Protocol, Sequence, Optional
@@ -10,8 +12,7 @@


 class Provider(Protocol):
-    def __call__(self, payload: dict[str, Any]) -> httpx.Response:
-        ...
+    def __call__(self, payload: dict[str, Any]) -> httpx.Response: ...


 class Message:
@@ -27,7 +28,7 @@ def content(self) -> str:
         return self._content

     @staticmethod
-    def from_dict(obj: Any) -> "Message":
+    def from_dict(obj: Any) -> Message:
         _role = str(obj.get("role"))
         _content = str(obj.get("content"))
         return Message(_role, _content)
@@ -48,7 +49,7 @@ def message(self) -> Message:
         return self._message

     @staticmethod
-    def from_dict(obj: Any) -> "Choice":
+    def from_dict(obj: Any) -> Choice:
         _index = int(obj.get("index"))
         _finish_reason = str(obj.get("finish_reason"))
         _message = Message.from_dict(obj.get("message"))
@@ -81,7 +82,7 @@ def total_tokens(self) -> int:
         return self._total

     @staticmethod
-    def from_dict(obj: Any) -> "Usage":
+    def from_dict(obj: Any) -> Usage:
         _completion_tokens = int(obj.get("completion_tokens"))
         _prompt_tokens = int(obj.get("prompt_tokens"))
         _total_tokens = int(obj.get("total_tokens"))
@@ -129,7 +130,7 @@ def to_consumptions(self) -> Sequence[Consumption]:
         ]

     @staticmethod
-    def from_dict(obj: Any) -> "OpenAIChatCompletionsResult":
+    def from_dict(obj: Any) -> OpenAIChatCompletionsResult:
         _id = str(obj.get("id"))
         _object = str(obj.get("object"))
         _created = int(obj.get("created"))
diff --git a/council/llm/openai_llm.py b/council/llm/openai_llm.py
index 1ced4a25..ed030933 100644
--- a/council/llm/openai_llm.py
+++ b/council/llm/openai_llm.py
@@ -4,7 +4,12 @@
 import httpx
 from httpx import TimeoutException, HTTPStatusError

-from . import OpenAIChatCompletionsModel, OpenAITokenCounter, LLMCallTimeoutException, LLMCallException
+from . import (
+    OpenAIChatCompletionsModel,
+    OpenAITokenCounter,
+    LLMCallTimeoutException,
+    LLMCallException,
+)
 from .llm_config_object import LLMConfigObject, LLMProviders
 from .openai_llm_configuration import OpenAILLMConfiguration

@@ -21,7 +26,10 @@ def __init__(self, config: OpenAILLMConfiguration, name: Optional[str] = None) -
         self._name = name

     def post_request(self, payload: dict[str, Any]) -> httpx.Response:
-        uri = "https://api.openai.com/v1/chat/completions"
+        """
+        Posts a request to the OpenAI chat completions endpoint.
+ """ + uri = self.config.api_host.unwrap() + "/v1/chat/completions" timeout = self.config.timeout.unwrap() try: @@ -48,8 +56,8 @@ def __init__(self, config: OpenAILLMConfiguration, name: Optional[str] = None): ) @staticmethod - def from_env(model: Optional[str] = None) -> OpenAILLM: - config: OpenAILLMConfiguration = OpenAILLMConfiguration.from_env(model=model) + def from_env(model: Optional[str] = None, api_host: Optional[str] = None) -> OpenAILLM: + config: OpenAILLMConfiguration = OpenAILLMConfiguration.from_env(model=model, api_host=api_host) return OpenAILLM(config) @staticmethod diff --git a/council/llm/openai_llm_configuration.py b/council/llm/openai_llm_configuration.py index ed6631d8..42bbbf46 100644 --- a/council/llm/openai_llm_configuration.py +++ b/council/llm/openai_llm_configuration.py @@ -3,7 +3,13 @@ from council.llm import LLMConfigurationBase from council.llm.llm_config_object import LLMConfigSpec -from council.utils import read_env_str, read_env_int, Parameter, greater_than_validator, prefix_validator +from council.utils import ( + read_env_str, + read_env_int, + Parameter, + greater_than_validator, + prefix_validator, +) from council.llm.llm_configuration_base import _DEFAULT_TIMEOUT _env_var_prefix = "OPENAI_" @@ -17,11 +23,12 @@ class OpenAILLMConfiguration(LLMConfigurationBase): * see https://platform.openai.com/docs/api-reference/chat """ - def __init__(self, api_key: str, model: str, timeout: int = _DEFAULT_TIMEOUT): + def __init__(self, api_key: str, api_host: str, model: str, timeout: int = _DEFAULT_TIMEOUT): """ Initialize a new instance of OpenAILLMConfiguration Args: api_key (str): the OpenAI api key + api_host (str): the OpenAI Host model (str): model version to use timeout (int): seconds to wait for response from OpenAI before timing out """ @@ -34,6 +41,14 @@ def __init__(self, api_key: str, model: str, timeout: int = _DEFAULT_TIMEOUT): name="api_key", required=True, value=api_key, validator=prefix_validator("sk-") ) + self._api_host = Parameter.string( + name="api_host", + required=False, + value=api_host, + default="https://api.openai.com", + validator=prefix_validator("http"), + ) + @property def model(self) -> Parameter[str]: """ @@ -48,6 +63,13 @@ def api_key(self) -> Parameter[str]: """ return self._api_key + @property + def api_host(self) -> Parameter[str]: + """ + OpenAI API Host + """ + return self._api_host + @property def timeout(self) -> Parameter[int]: """ @@ -62,22 +84,29 @@ def build_default_payload(self) -> dict[str, Any]: return payload @staticmethod - def from_env(model: Optional[str] = None) -> OpenAILLMConfiguration: + def from_env(model: Optional[str] = None, api_host: Optional[str] = None) -> OpenAILLMConfiguration: api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap() + if api_host is None: + api_host = read_env_str( + _env_var_prefix + "API_HOST", required=False, default="https://api.openai.com" + ).unwrap() + if model is None: model = read_env_str(_env_var_prefix + "LLM_MODEL", required=False, default="gpt-3.5-turbo").unwrap() + timeout = read_env_int(_env_var_prefix + "LLM_TIMEOUT", required=False, default=_DEFAULT_TIMEOUT).unwrap() - config = OpenAILLMConfiguration(model=model, api_key=api_key, timeout=timeout) + config = OpenAILLMConfiguration(model=model, api_key=api_key, api_host=api_host, timeout=timeout) config.read_env(_env_var_prefix) return config @staticmethod def from_spec(spec: LLMConfigSpec) -> OpenAILLMConfiguration: api_key: str = spec.provider.must_get_value("apiKey") + api_host: str = 
spec.provider.get_value("apiHost") or "https://api.openai.com" model: str = spec.provider.must_get_value("model") - config = OpenAILLMConfiguration(api_key=api_key, model=str(model)) + config = OpenAILLMConfiguration(api_key=api_key, api_host=api_host, model=str(model)) if spec.parameters is not None: config.from_dict(spec.parameters) diff --git a/council/mocks/__init__.py b/council/mocks/__init__.py index f7f86f38..6970b96d 100644 --- a/council/mocks/__init__.py +++ b/council/mocks/__init__.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import random import time -from typing import Any, Callable, List, Optional, Protocol, Sequence +from typing import Any, Callable, Iterable, List, Optional, Protocol, Sequence from council.agents import Agent, AgentResult from council.contexts import ( @@ -34,8 +36,7 @@ def call(self, _messages: Sequence[LLMMessage]) -> Sequence[str]: class LLMMessagesToStr(Protocol): - def __call__(self, messages: Sequence[LLMMessage]) -> Sequence[str]: - ... + def __call__(self, messages: Sequence[LLMMessage]) -> Sequence[str]: ... def llm_message_content_to_str(messages: Sequence[LLMMessage]) -> Sequence[str]: @@ -72,7 +73,7 @@ def set_action_custom_message(self, message: str) -> None: self._action = lambda context: self.build_success_message(message) @staticmethod - def build_wait_skill(duration: int = 1, message: str = "done") -> "MockSkill": + def build_wait_skill(duration: int = 1, message: str = "done") -> MockSkill: def wait_a_message(context: SkillContext) -> ChatMessage: time.sleep(duration) return ChatMessage.skill(message) @@ -93,15 +94,15 @@ def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage] return LLMResult(choices=[f"{self.__class__.__name__}"]) @staticmethod - def from_responses(responses: List[str]) -> "MockLLM": + def from_responses(responses: List[str]) -> MockLLM: return MockLLM(action=(lambda x: responses)) @staticmethod - def from_response(response: str) -> "MockLLM": + def from_response(response: str) -> MockLLM: return MockLLM(action=(lambda x: [response])) @staticmethod - def from_multi_line_response(responses: List[str]) -> "MockLLM": + def from_multi_line_response(responses: Iterable[str]) -> MockLLM: response = "\n".join(responses) return MockLLM(action=(lambda x: [response])) diff --git a/council/runners/runner_base.py b/council/runners/runner_base.py index f3659504..2ca2246d 100644 --- a/council/runners/runner_base.py +++ b/council/runners/runner_base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from collections.abc import Set from concurrent import futures @@ -8,14 +10,14 @@ class RunnerBase(Monitorable, abc.ABC): - def run_from_chain_context(self, context: ChainContext, executor: RunnerExecutor): + def run_from_chain_context(self, context: ChainContext, executor: RunnerExecutor) -> None: self.run(context, executor) """ Base runner class that handles common execution logic, including error management and timeout """ - def fork_run_merge(self, runner: Monitored["RunnerBase"], context: ChainContext, executor: RunnerExecutor): + def fork_run_merge(self, runner: Monitored[RunnerBase], context: ChainContext, executor: RunnerExecutor): inner = context.fork_for(runner) try: runner.inner.run(inner, executor) diff --git a/council/scorers/llm_similarity_scorer.py b/council/scorers/llm_similarity_scorer.py index 72ae457b..257c5cfc 100644 --- a/council/scorers/llm_similarity_scorer.py +++ b/council/scorers/llm_similarity_scorer.py @@ -4,7 +4,7 @@ from .scorer_base import ScorerBase from 
 from council.contexts import ChatMessage, ScorerContext, ContextBase
 from council.llm import LLMBase, LLMMessage, MonitoredLLM, llm_property, LLMAnswer
-from ..llm.llm_answer import LLMParsingException
+from ..llm.llm_answer import LLMParsingException, llm_class_validator
 from ..utils import Option

@@ -26,6 +26,11 @@ def justification(self) -> str:
     def __str__(self):
         return f"Similarity score is {self.score} with the justification: {self._justification}"

+    @llm_class_validator
+    def validate(self):
+        if self._score < 0 or self._score > 100:
+            raise LLMParsingException(f"Similarity Score `{self._score}` is invalid, value must be between 0 and 100.")
+

 class LLMSimilarityScorer(ScorerBase):
     """
@@ -100,7 +105,7 @@ def _build_system_message(self) -> LLMMessage:
             "1. Compare the {expected} message and the {actual} message.",
             "2. Score 0 (2 messages are unrelated) to 100 (the 2 messages have the same content).",
             "3. Your score must be fair.",
-            "\n#FORMATTING",
+            "\n# FORMATTING",
             self._llm_answer.to_prompt(),
         ]
         return LLMMessage.system_message("\n".join(system_prompt))
diff --git a/council/scorers/scorer_base.py b/council/scorers/scorer_base.py
index 13f96e0f..d33e4c59 100644
--- a/council/scorers/scorer_base.py
+++ b/council/scorers/scorer_base.py
@@ -29,9 +29,9 @@ def score(self, context: ScorerContext, message: ChatMessage) -> float:
         """
         try:
             return self._score(context, message)
-        except Exception:
+        except Exception as e:
             context.logger.exception('message="execution failed"')
-            raise ScorerException
+            raise ScorerException from e

     @abc.abstractmethod
     def _score(self, context: ScorerContext, message: ChatMessage) -> float:
@@ -42,6 +42,6 @@ def _score(self, context: ScorerContext, message: ChatMessage) -> float:

     def to_dict(self) -> Dict[str, Any]:
         """
-        Serialize the instance into a dictionary. May need to be overriden in derived classes
+        Serialize the instance into a dictionary. May need to be overridden in derived classes
         """
         return {"type": self.__class__.__name__}
diff --git a/council/skills/google/google_context/google_search.py b/council/skills/google/google_context/google_search.py
index e7bc283e..b23c7377 100644
--- a/council/skills/google/google_context/google_search.py
+++ b/council/skills/google/google_context/google_search.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC
 from typing import Optional, Any

@@ -51,7 +53,7 @@ def from_metadata(result: Any) -> Optional[ResponseReference]:
         return None

     @classmethod
-    def from_env(cls) -> Optional["GoogleSearchEngine"]:
+    def from_env(cls) -> Optional[GoogleSearchEngine]:
         try:
             api_key: str = read_env_str("GOOGLE_API_KEY").unwrap()
             engine_id: str = read_env_str("GOOGLE_SEARCH_ENGINE_ID").unwrap()
diff --git a/council/skills/llm_skill.py b/council/skills/llm_skill.py
index cd23ba37..743a9881 100644
--- a/council/skills/llm_skill.py
+++ b/council/skills/llm_skill.py
@@ -7,8 +7,7 @@

 class ReturnMessages(Protocol):
-    def __call__(self, context: SkillContext) -> List[LLMMessage]:
-        ...
+    def __call__(self, context: SkillContext) -> List[LLMMessage]: ...

 def get_chat_history(context: SkillContext) -> List[LLMMessage]:
diff --git a/council/skills/python/python_code_execution_skill.py b/council/skills/python/python_code_execution_skill.py
index 0f1db1c2..20c24154 100644
--- a/council/skills/python/python_code_execution_skill.py
+++ b/council/skills/python/python_code_execution_skill.py
@@ -30,7 +30,8 @@ def __init__(self, env_var: Optional[Mapping[str, str]] = None, decode_stdout: b
         """
         super().__init__("python code runner")

-        self._env_var = os.environ.copy() | (env_var or {})
+        self._env_var = os.environ.copy()
+        self._env_var.update(env_var or {})
         self._decode_stdout = decode_stdout

     def execute(self, context: SkillContext) -> ChatMessage:
diff --git a/council/utils/option.py b/council/utils/option.py
index cad3e9f5..1db81e56 100644
--- a/council/utils/option.py
+++ b/council/utils/option.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TypeVar, Generic, Optional, Callable

 T = TypeVar("T")
@@ -91,7 +93,7 @@ def is_some(self) -> bool:
         return not self.is_none()

     @staticmethod
-    def some(some: T) -> "Option[T]":
+    def some(some: T) -> Option[T]:
         """
         Create a new instance with some value.

@@ -103,7 +105,7 @@ def some(some: T) -> "Option[T]":
         return Option(some)

     @staticmethod
-    def none() -> "Option[T]":
+    def none() -> Option[T]:
         """
         Create a new instance with none

diff --git a/council/utils/parameter.py b/council/utils/parameter.py
index ad342d09..e52ff8b8 100644
--- a/council/utils/parameter.py
+++ b/council/utils/parameter.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Callable, TypeVar, Optional, Generic, Any, Union

 from council.utils import Option, read_env_int, read_env_float, read_env_str
@@ -75,22 +77,23 @@ def __init__(
         self._name = name
         self._required = required
         self._validator: Validator = validator if validator is not None else lambda x: None
+        self._default = default

         if isinstance(value, Undefined):
-            self._value: Option[T] = Option.none()
+            if not isinstance(default, Undefined):
+                self.set(default)
+            else:
+                self._value: Option[T] = Option.none()
         else:
             self.set(value)

         self._read_env = converter
-        self._default = default
-        if not isinstance(default, Undefined):
-            self.set(default)

-    def from_env(self, env_var: str):
+    def from_env(self, env_var: str) -> None:
         v = self._read_env(env_var, self._required)
         if v.is_some():
             self.set(v.unwrap())

-    def set(self, value: Optional[T]):
+    def set(self, value: Optional[T]) -> None:
         try:
             self._validator(value)
             self._value = Option(value)
@@ -98,7 +101,7 @@ def set(self, value: Optional[T]):
             raise ParameterValueException(self._name, value=value, message=e)

     @property
-    def name(self):
+    def name(self) -> str:
         return self._name

     @property
@@ -135,6 +138,16 @@ def __str__(self) -> str:
         default = f" Default value `{self._default}`." if not isinstance(self._default, Undefined) else ""
         return f"Parameter{opt} `{self._name}` with {val}.{default}"

+    def __eq__(self, other: Any) -> bool:
+        if self.is_none():
+            if isinstance(other, Parameter):
+                return other.is_none()
+            return False
+
+        if isinstance(other, Parameter):
+            return self.unwrap() == other.unwrap()
+        return self.unwrap() == other
+
     @staticmethod
     def string(
         name: str,
@@ -142,9 +155,14 @@ def string(
         value: OptionalOrUndefined[str] = _undefined,
         default: OptionalOrUndefined[str] = _undefined,
         validator: Optional[Validator] = None,
-    ) -> "Parameter[str]":
+    ) -> Parameter[str]:
         return Parameter(
-            name=name, required=required, value=value, converter=read_env_str, default=default, validator=validator
+            name=name,
+            required=required,
+            value=value,
+            converter=read_env_str,
+            default=default,
+            validator=validator,
         )

     @staticmethod
@@ -154,9 +172,14 @@ def int(
         name: str,
         required: bool = False,
         value: OptionalOrUndefined[int] = _undefined,
        default: OptionalOrUndefined[int] = _undefined,
         validator: Optional[Validator] = None,
-    ) -> "Parameter[int]":
+    ) -> Parameter[int]:
         return Parameter(
-            name=name, required=required, value=value, converter=read_env_int, default=default, validator=validator
+            name=name,
+            required=required,
+            value=value,
+            converter=read_env_int,
+            default=default,
+            validator=validator,
         )

     @staticmethod
@@ -166,7 +189,12 @@ def float(
         value: OptionalOrUndefined[float] = _undefined,
         default: OptionalOrUndefined[float] = _undefined,
         validator: Optional[Validator] = None,
-    ) -> "Parameter[float]":
+    ) -> Parameter[float]:
         return Parameter(
-            name=name, required=required, value=value, converter=read_env_float, default=default, validator=validator
+            name=name,
+            required=required,
+            value=value,
+            converter=read_env_float,
+            default=default,
+            validator=validator,
         )
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 9176ee2e..edb8e1a9 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,12 +1,12 @@
 hatch==1.7.0

 # Lint
-black==23.10.1
-mypy==1.6.1
+black==24.2.0
+mypy==1.9.0
 ruff==0.1.4

 # Test
 ipykernel==6.26.0
 nbconvert==7.11.0
 nbformat==5.9.2
-pytest==7.4.3
+pytest==7.4.4
diff --git a/docs/source/getting_started/first_example.ipynb b/docs/source/getting_started/first_example.ipynb
index 67e899b4..9da16e57 100644
--- a/docs/source/getting_started/first_example.ipynb
+++ b/docs/source/getting_started/first_example.ipynb
@@ -49,7 +49,7 @@
    "import os\n",
    "\n",
    "dotenv.load_dotenv()\n",
-    "print(os.getenv(\"OPENAI_API_KEY\", None) is not None )\n",
+    "print(os.getenv(\"OPENAI_API_KEY\", None) is not None)\n",
    "\n",
    "openai_llm = OpenAILLM.from_env()\n"
   ],
   "metadata": {
@@ -73,7 +73,7 @@
   "source": [
    "prompt = \"You are responding to every prompt with a short poem titled hello world\"\n",
    "hw_skill = LLMSkill(llm=openai_llm, system_prompt=prompt)\n",
-    "hw_chain = Chain(name=\"Hello World\", description=\"Answers with a poem about titled Hello World\", runners=[hw_skill])\n"
+    "hw_chain = Chain(name=\"Hello World\", description=\"Answers with a poem titled Hello World\", runners=[hw_skill])\n"
   ],
   "metadata": {
    "collapsed": false
diff --git a/tests/data/openai-llmodel.yaml b/tests/data/openai-llmodel.yaml
index 24f9f1c5..5424e7ea 100644
--- a/tests/data/openai-llmodel.yaml
+++ b/tests/data/openai-llmodel.yaml
@@ -13,6 +13,8 @@ spec:
     timeout: 60
     apiKey:
       fromEnvVar: OPENAI_API_KEY
+    apiHost:
+      fromEnvVar: OPENAI_API_HOST
   parameters:
     n: 3
     temperature: 0.5
diff --git a/tests/unit/llm/test_llm_answer.py b/tests/unit/llm/test_llm_answer.py
index 6bc2cf4e..85bb18f0 100644
--- a/tests/unit/llm/test_llm_answer.py
+++ b/tests/unit/llm/test_llm_answer.py
@@ -1,7 +1,7 @@
 import unittest

 from council.controllers.llm_controller import Specialist
-from council.llm import LLMAnswer
+from council.llm import LLMAnswer, LLMParsingException


 class TestLLMFallBack(unittest.TestCase):
@@ -20,7 +20,11 @@ def test_llm_parse_line_answer(self):
         print(llma.parse_line("Instructions: None<->Name: first<->Score: ABC<->Justification: because"))

         cs = llma.to_object("Instructions: None<->nAme: first<->Score: 10<->Justification: because")
-        self.assertEqual(cs.score, 10)
+        self.assertEqual(10, cs.score)
+
+        with self.assertRaises(LLMParsingException) as e:
+            _ = llma.to_object("Instructions: None<->nAme: first<->Score: 20<->Justification: because")
+        print(f"exception: {e.exception}")

     def test_llm_parse_yaml_answer(self):
         llma = LLMAnswer(Specialist)
diff --git a/tests/unit/llm/test_llm_config_object.py b/tests/unit/llm/test_llm_config_object.py
index 875f6d4a..325e0229 100644
--- a/tests/unit/llm/test_llm_config_object.py
+++ b/tests/unit/llm/test_llm_config_object.py
@@ -1,5 +1,5 @@
 from council import OpenAILLM, AzureLLM, AnthropicLLM
-from council.llm import get_llm_from_config, LLMFallback
+from council.llm import get_llm_from_config, LLMFallback, OpenAILLMConfiguration
 from council.llm.llm_config_object import LLMConfigObject
 from council.utils import OsEnviron

@@ -9,14 +9,21 @@
 def test_openai_from_yaml():
     filename = get_data_filename(LLModels.OpenAI)

-    with OsEnviron("OPENAI_API_KEY", "sk-key"):
+    with OsEnviron("OPENAI_API_KEY", "sk-key"), OsEnviron("OPENAI_API_HOST", "https://openai.com"):
         actual = LLMConfigObject.from_yaml(filename)
         assert actual.spec.provider.name == "CML-OpenAI"

         llm = OpenAILLM.from_config(actual)
         assert isinstance(llm, OpenAILLM)
-        assert llm.config.temperature.value == 0.5
-        assert llm.config.n.value == 3
+
+        assert isinstance(llm.config, OpenAILLMConfiguration)
+        config: OpenAILLMConfiguration = llm.config
+        assert config.temperature == 0.5
+        assert config.n == 3
+        assert config.api_host == "https://openai.com"
+
+    llm = get_llm_from_config(filename)
+    assert isinstance(llm, OpenAILLM)


 def test_azure_from_yaml():
@@ -28,6 +35,8 @@ def test_azure_from_yaml():

         llm = AzureLLM.from_config(actual)
         assert isinstance(llm, AzureLLM)
+    llm = get_llm_from_config(filename)
+    assert isinstance(llm, AzureLLM)


 def test_anthropic_from_yaml():
@@ -39,6 +48,9 @@ def test_anthropic_from_yaml():
         assert isinstance(llm, AnthropicLLM)
         assert llm.config.top_k.value == 8

+    llm = get_llm_from_config(filename)
+    assert isinstance(llm, AnthropicLLM)
+

 def test_azure_with_openai_fallback_from_yaml():
     filename = get_data_filename(LLModels.AzureWithFallback)
diff --git a/tests/unit/llm/test_openai_llm_configuration.py b/tests/unit/llm/test_openai_llm_configuration.py
index 38d861c7..c95cdbd3 100644
--- a/tests/unit/llm/test_openai_llm_configuration.py
+++ b/tests/unit/llm/test_openai_llm_configuration.py
@@ -17,7 +17,7 @@ def test_model_override(self):
         self.assertEqual("gpt-not-default", config.model.value)

     def test_default(self):
-        config = OpenAILLMConfiguration(model="gpt-model", api_key="sk-key")
+        config = OpenAILLMConfiguration(model="gpt-model", api_key="sk-key", api_host="https://api.openai.com")
         self.assertEqual(0.0, config.temperature.value)
         self.assertEqual(1, config.n.value)
         self.assertTrue(config.top_p.is_none())
@@ -26,6 +26,8 @@ def test_default(self):

     def test_invalid(self):
         with self.assertRaises(ParameterValueException):
OpenAILLMConfiguration(model="a-gpt-model", api_key="sk-key") + _ = OpenAILLMConfiguration(model="a-gpt-model", api_key="sk-key", api_host="https://api.openai.com") with self.assertRaises(ParameterValueException): - _ = OpenAILLMConfiguration(model="gpt-model", api_key="a-sk-key") + _ = OpenAILLMConfiguration(model="gpt-model", api_key="a-sk-key", api_host="https://api.openai.com") + with self.assertRaises(ParameterValueException): + _ = OpenAILLMConfiguration(model="gpt-model", api_key="sk-key", api_host="api.openai.com") diff --git a/tests/unit/utils/test_parameter.py b/tests/unit/utils/test_parameter.py index 50f7d4ed..80c2079c 100644 --- a/tests/unit/utils/test_parameter.py +++ b/tests/unit/utils/test_parameter.py @@ -1,7 +1,7 @@ import unittest from council.utils import MissingEnvVariableException, EnvVariableValueException, OsEnviron -from council.utils.parameter import ParameterValueException, Parameter +from council.utils.parameter import ParameterValueException, Parameter, Undefined def tv(x: float): @@ -68,3 +68,15 @@ def test_from_env_no_validation(self) -> None: with OsEnviron("TEST_LLM_TEMPERATURE", "9876.54321"): temperature.from_env("TEST_LLM_TEMPERATURE") self.assertEqual(temperature.unwrap(), 9876.54321) + + def test_equal(self) -> None: + temperature1: Parameter[float] = Parameter.float(name="temperature1", required=False, value=12.34) + temperature2: Parameter[float] = Parameter.float(name="temperature2", required=True, value=12.34) + temperature3: Parameter[float] = Parameter.float(name="temperature3", required=True, value=43.21) + + self.assertTrue(temperature1 == temperature2) + self.assertTrue(temperature1 == 12.34) + + self.assertFalse(temperature1 == temperature3) + self.assertFalse(temperature1 == 43.21) + self.assertFalse(temperature1 == Undefined())