diff --git a/council/llm/__init__.py b/council/llm/__init__.py index 51024ba2..4617f623 100644 --- a/council/llm/__init__.py +++ b/council/llm/__init__.py @@ -13,6 +13,7 @@ LLMCostCard, LLMConsumptionCalculatorBase, DefaultLLMConsumptionCalculator, + DefaultLLMConsumptionCalculatorHelper, TokenKind, LLMCostManagerSpec, LLMCostManagerObject, @@ -61,6 +62,9 @@ from .ollama_llm_configuration import OllamaLLMConfiguration from .ollama_llm import OllamaLLM +from .groq_llm_configuration import GroqLLMConfiguration +from .groq_llm import GroqLLM + def get_default_llm(max_retries: Optional[int] = None) -> LLMBase: provider = read_env_str("COUNCIL_DEFAULT_LLM_PROVIDER", default=LLMProviders.OpenAI).unwrap() @@ -77,6 +81,8 @@ def get_default_llm(max_retries: Optional[int] = None) -> LLMBase: llm = GeminiLLM.from_env() elif provider == LLMProviders.Ollama.lower(): llm = OllamaLLM.from_env() + elif provider == LLMProviders.Groq.lower(): + llm = GroqLLM.from_env() if llm is None: raise ValueError(f"Provider {provider} not supported by Council.") @@ -111,5 +117,7 @@ def _build_llm(llm_config: LLMConfigObject) -> LLMBase: return GeminiLLM.from_config(llm_config) elif provider.is_of_kind(LLMProviders.Ollama): return OllamaLLM.from_config(llm_config) + elif provider.is_of_kind(LLMProviders.Groq): + return GroqLLM.from_config(llm_config) raise ValueError(f"Provider `{provider.kind}` not supported by Council") diff --git a/council/llm/chat_gpt_configuration.py b/council/llm/chat_gpt_configuration.py index df2858b6..fc3fef68 100644 --- a/council/llm/chat_gpt_configuration.py +++ b/council/llm/chat_gpt_configuration.py @@ -2,34 +2,7 @@ from typing import Any, Dict, Optional from council.llm.llm_base import LLMConfigurationBase -from council.utils import Parameter - - -def _tv(x: float): - """ - Temperature Validator - Sampling temperature to use, between 0. and 2. 
- """ - if x < 0.0 or x > 2.0: - raise ValueError("must be in the range [0.0..2.0]") - - -def _pv(x: float): - """ - Penalty Validator - Penalty must be between -2.0 and 2.0 - """ - if x < -2.0 or x > 2.0: - raise ValueError("must be in the range [-2.0..2.0]") - - -def _mtv(x: int): - """ - Max Token Validator - Must be positive - """ - if x <= 0: - raise ValueError("must be positive") +from council.utils import Parameter, penalty_validator, positive_validator, zero_to_one_validator, zero_to_two_validator class ChatGPTConfigurationBase(LLMConfigurationBase, ABC): @@ -38,12 +11,14 @@ class ChatGPTConfigurationBase(LLMConfigurationBase, ABC): """ def __init__(self) -> None: - self._temperature = Parameter.float(name="temperature", required=False, default=0.0, validator=_tv) - self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=_mtv) - self._top_p = Parameter.float(name="top_p", required=False) - self._n = Parameter.int(name="n", required=False, default=1) - self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=_pv) - self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=_pv) + self._temperature = Parameter.float( + name="temperature", required=False, default=0.0, validator=zero_to_two_validator + ) + self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=positive_validator) + self._top_p = Parameter.float(name="top_p", required=False, validator=zero_to_one_validator) + self._n = Parameter.int(name="n", required=False, default=1, validator=positive_validator) + self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=penalty_validator) + self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=penalty_validator) @property def temperature(self) -> Parameter[float]: diff --git a/council/llm/data/groq-costs.yaml b/council/llm/data/groq-costs.yaml new file mode 100644 index 00000000..dc4cde42 --- /dev/null +++ b/council/llm/data/groq-costs.yaml @@ -0,0 +1,54 @@ +kind: LLMCostManager +version: 0.1 +metadata: + name: groq-costs + labels: + provider: Groq + reference: https://groq.com/pricing/ +spec: + default: + description: | + Costs for Groq LLMs + models: + gemma2-9b-it: + input: 0.20 + output: 0.20 + gemma-7b-it: + input: 0.07 + output: 0.07 + llama3-groq-70b-8192-tool-use-preview: + input: 0.89 + output: 0.89 + llama3-groq-8b-8192-tool-use-preview: + input: 0.19 + output: 0.19 + llama-3.1-70b-versatile: + input: 0.59 + output: 0.79 + llama-3.1-8b-instant: + input: 0.05 + output: 0.08 + llama-3.2-1b-preview: + input: 0.04 + output: 0.04 + llama-3.2-3b-preview: + input: 0.06 + output: 0.06 + llama-3.2-11b-vision-preview: + input: 0.18 + output: 0.18 + llama-3.2-90b-vision-preview: + input: 0.90 + output: 0.90 + llama-guard-3-8b: + input: 0.20 + output: 0.20 + llama3-70b-8192: + input: 0.59 + output: 0.79 + llama3-8b-8192: + input: 0.05 + output: 0.08 + mixtral-8x7b-32768: + input: 0.24 + output: 0.24 \ No newline at end of file diff --git a/council/llm/groq_llm.py b/council/llm/groq_llm.py new file mode 100644 index 00000000..63a1c6ee --- /dev/null +++ b/council/llm/groq_llm.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Mapping, Optional, Sequence + +from council.contexts import Consumption, LLMContext +from council.llm import ( + DefaultLLMConsumptionCalculatorHelper, + GroqLLMConfiguration, + LLMBase, + LLMConfigObject, + LLMCostCard, + 
LLMCostManagerObject, + LLMMessage, + LLMMessageRole, + LLMProviders, + LLMResult, +) +from council.utils.utils import DurationManager +from groq import Groq +from groq.types import CompletionUsage +from groq.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, +) +from groq.types.chat.chat_completion import ChatCompletion, Choice + + +class GroqConsumptionCalculator(DefaultLLMConsumptionCalculatorHelper): + _cost_manager = LLMCostManagerObject.groq() + COSTS: Mapping[str, LLMCostCard] = _cost_manager.get_cost_map("default") + + def __init__(self, model: str) -> None: + super().__init__(model) + + def get_consumptions(self, duration: float, usage: Optional[CompletionUsage]) -> List[Consumption]: + if usage is None: + return self.get_default_consumptions(duration) + + prompt_tokens = usage.prompt_tokens + completion_tokens = usage.completion_tokens + return ( + self.get_base_consumptions(duration, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + + self.get_duration_consumptions(usage) + + self.get_cost_consumptions(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + ) + + def get_duration_consumptions(self, usage: CompletionUsage) -> List[Consumption]: + """Optional duration consumptions specific to Groq.""" + usage_times: Dict[str, Optional[float]] = { + "queue_time": usage.queue_time, + "prompt_time": usage.prompt_time, + "completion_time": usage.completion_time, + "total_time": usage.total_time, + } + + consumptions = [] + for key, value in usage_times.items(): + if value is not None: + consumptions.append(Consumption.duration(value, f"{self.model}:groq_{key}")) + + return consumptions + + def find_model_costs(self) -> Optional[LLMCostCard]: + return self.COSTS.get(self.model) + + +class GroqLLM(LLMBase[GroqLLMConfiguration]): + def __init__(self, config: GroqLLMConfiguration) -> None: + """ + Initialize a new instance. 
+ + Args: + config(GroqLLMConfiguration): configuration for the instance + """ + super().__init__(name=f"{self.__class__.__name__}", configuration=config) + self._client = Groq(api_key=config.api_key.value) + + def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult: + formatted_messages = self._build_messages_payload(messages) + + with DurationManager() as timer: + response = self._client.chat.completions.create( + messages=formatted_messages, + model=self._configuration.model_name(), + **self._configuration.params_to_args(), + **kwargs, + ) + + return LLMResult( + choices=self._to_choices(response.choices), + consumptions=self._to_consumptions(timer.duration, response), + raw_response=response.to_dict(), + ) + + @staticmethod + def _build_messages_payload(messages: Sequence[LLMMessage]) -> List[ChatCompletionMessageParam]: + def _llm_message_to_groq_message(message: LLMMessage) -> ChatCompletionMessageParam: + if message.is_of_role(LLMMessageRole.System): + return ChatCompletionSystemMessageParam(role="system", content=message.content) + elif message.is_of_role(LLMMessageRole.User): + return ChatCompletionUserMessageParam(role="user", content=message.content) + elif message.is_of_role(LLMMessageRole.Assistant): + return ChatCompletionAssistantMessageParam(role="assistant", content=message.content) + + raise ValueError(f"Unknown LLMessage role: `{message.role.value}`") + + return [_llm_message_to_groq_message(message) for message in messages] + + @staticmethod + def _to_choices(choices: List[Choice]) -> List[str]: + return [choice.message.content if choice.message.content is not None else "" for choice in choices] + + @staticmethod + def _to_consumptions(duration: float, response: ChatCompletion) -> Sequence[Consumption]: + calculator = GroqConsumptionCalculator(response.model) + return calculator.get_consumptions(duration, response.usage) + + @staticmethod + def from_env() -> GroqLLM: + """ + Helper function that create a new instance by getting the configuration from environment variables. + + Returns: + GroqLLM + """ + return GroqLLM(GroqLLMConfiguration.from_env()) + + @staticmethod + def from_config(config_object: LLMConfigObject) -> GroqLLM: + provider = config_object.spec.provider + if not provider.is_of_kind(LLMProviders.Groq): + raise ValueError(f"Invalid LLM provider, actual {provider}, expected {LLMProviders.Groq}") + + config = GroqLLMConfiguration.from_spec(config_object.spec) + return GroqLLM(config=config) diff --git a/council/llm/groq_llm_configuration.py b/council/llm/groq_llm_configuration.py new file mode 100644 index 00000000..71a0bbcb --- /dev/null +++ b/council/llm/groq_llm_configuration.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from typing import Any, Dict, Final, Mapping, Tuple, Type + +from council.utils import ( + Parameter, + not_empty_validator, + penalty_validator, + positive_validator, + read_env_str, + zero_to_one_validator, + zero_to_two_validator, +) + +from . 
import LLMConfigSpec, LLMConfigurationBase + +_env_var_prefix: Final[str] = "GROQ_" + + +class GroqLLMConfiguration(LLMConfigurationBase): + def __init__(self, model: str, api_key: str) -> None: + """ + Initialize a new instance + + Args: + api_key (str): the api key + model (str): the model name + """ + super().__init__() + self._model = Parameter.string(name="model", required=True, value=model, validator=not_empty_validator) + self._api_key = Parameter.string(name="api_key", required=True, value=api_key, validator=not_empty_validator) + + # https://console.groq.com/docs/api-reference#chat + self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=penalty_validator) + self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=positive_validator) + self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=penalty_validator) + self._seed = Parameter.int(name="seed", required=False) + self._stop = Parameter.string(name="stop", required=False) + self._temperature = Parameter.float( + name="temperature", required=False, default=0.0, validator=zero_to_two_validator + ) + self._top_p = Parameter.float(name="top_p", required=False, validator=zero_to_one_validator) + + def model_name(self) -> str: + return self._model.unwrap() + + @property + def model(self) -> Parameter[str]: + """ + Groq model name + """ + return self._model + + @property + def api_key(self) -> Parameter[str]: + """ + Groq API Key + """ + return self._api_key + + @property + def frequency_penalty(self) -> Parameter[float]: + """ + Number between -2.0 and 2.0. + Positive values penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + """ + return self._frequency_penalty + + @property + def max_tokens(self) -> Parameter[int]: + """Maximum number of tokens to generate.""" + return self._max_tokens + + @property + def presence_penalty(self) -> Parameter[float]: + """ + Number between -2.0 and 2.0. + Positive values penalize new tokens based on whether they appear in the text so far, + increasing the model's likelihood to talk about new topics. + """ + return self._presence_penalty + + @property + def seed(self) -> Parameter[int]: + """Random seed for generation.""" + return self._seed + + @property + def stop(self) -> Parameter[str]: + """Stop sequence.""" + return self._stop + + @property + def temperature(self) -> Parameter[float]: + """ + What sampling temperature to use, between 0 and 2. 
+ """ + return self._temperature + + @property + def top_p(self) -> Parameter[float]: + """Nucleus sampling threshold.""" + return self._top_p + + def params_to_args(self) -> Dict[str, Any]: + """Convert parameters to options dict""" + return { + "frequency_penalty": self.frequency_penalty.value, + "max_tokens": self.max_tokens.value, + "presence_penalty": self.presence_penalty.value, + "seed": self.seed.value, + "stop": self.stop.value, + "temperature": self.temperature.value, + "top_p": self.top_p.value, + } + + @staticmethod + def from_env() -> GroqLLMConfiguration: + api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap() + model = read_env_str(_env_var_prefix + "LLM_MODEL").unwrap() + config = GroqLLMConfiguration(model=model, api_key=api_key) + return config + + @staticmethod + def from_spec(spec: LLMConfigSpec) -> GroqLLMConfiguration: + api_key = spec.provider.must_get_value("apiKey") + model = spec.provider.must_get_value("model") + config = GroqLLMConfiguration(model=str(model), api_key=str(api_key)) + + if spec.parameters is not None: + param_mapping: Mapping[str, Tuple[Parameter, Type]] = { + "frequencyPenalty": (config.frequency_penalty, float), + "maxTokens": (config.max_tokens, int), + "presencePenalty": (config.presence_penalty, float), + "seed": (config.seed, int), + "stop": (config.stop, list), + "temperature": (config.temperature, float), + "topP": (config.top_p, float), + } + + for key, (param, type_conv) in param_mapping.items(): + value = spec.parameters.get(key) + if value is not None: + param.set(type_conv(value)) + + return config diff --git a/council/llm/llm_config_object.py b/council/llm/llm_config_object.py index 8b2148f5..64b5eb85 100644 --- a/council/llm/llm_config_object.py +++ b/council/llm/llm_config_object.py @@ -15,6 +15,7 @@ class LLMProviders(str, Enum): Anthropic = "anthropicSpec" Gemini = "googleGeminiSpec" Ollama = "ollamaSpec" + Groq = "groqSpec" class LLMProvider: @@ -51,6 +52,9 @@ def from_dict(cls, values: Dict[str, Any]) -> LLMProvider: spec = values.get(LLMProviders.Ollama) if spec is not None: return LLMProvider(name, description, spec, LLMProviders.Ollama) + spec = values.get(LLMProviders.Groq) + if spec is not None: + return LLMProvider(name, description, spec, LLMProviders.Groq) raise ValueError("Unsupported model provider") def to_dict(self) -> Dict[str, Any]: @@ -65,6 +69,8 @@ def to_dict(self) -> Dict[str, Any]: result[LLMProviders.Gemini] = self._specs elif self.is_of_kind(LLMProviders.Ollama): result[LLMProviders.Ollama] = self._specs + elif self.is_of_kind(LLMProviders.Groq): + result[LLMProviders.Groq] = self._specs return result def must_get_value(self, key: str) -> Any: diff --git a/council/llm/llm_cost.py b/council/llm/llm_cost.py index 54feda6c..88347e54 100644 --- a/council/llm/llm_cost.py +++ b/council/llm/llm_cost.py @@ -13,6 +13,7 @@ ANTHROPIC_COSTS_FILENAME: Final[str] = "anthropic-costs.yaml" GEMINI_COSTS_FILENAME: Final[str] = "gemini-costs.yaml" OPENAI_COSTS_FILENAME: Final[str] = "openai-costs.yaml" +GROQ_COSTS_FILENAME: Final[str] = "groq-costs.yaml" class LLMCostCard: @@ -88,6 +89,10 @@ def get_consumptions(self, *args, **kwargs) -> List[Consumption]: """Each calculator will implement with its own parameters.""" pass + def get_default_consumptions(self, duration: float) -> List[Consumption]: + """1 call and specified duration consumptions. 
To use when token info is not available""" + return [Consumption.call(1, self.model), Consumption.duration(duration, self.model)] + @abc.abstractmethod def find_model_costs(self) -> Optional[LLMCostCard]: """Get LLMCostCard for self to calculate cost consumptions.""" @@ -98,21 +103,7 @@ def filter_zeros(consumptions: List[Consumption]) -> List[Consumption]: return list(filter(lambda consumption: consumption.value > 0, consumptions)) -class DefaultLLMConsumptionCalculator(LLMConsumptionCalculatorBase, abc.ABC): - def get_consumptions(self, duration: float, *, prompt_tokens: int, completion_tokens: int) -> List[Consumption]: - """ - Get default consumptions: - - 1 call - - specified duration - - prompt, completion and total tokens - - corresponding costs if LLMCostCard can be found. - """ - base_consumptions = self.get_base_consumptions( - duration, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens - ) - cost_consumptions = self.get_cost_consumptions(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - return base_consumptions + cost_consumptions - +class DefaultLLMConsumptionCalculatorHelper(LLMConsumptionCalculatorBase, abc.ABC): def get_base_consumptions( self, duration: float, *, prompt_tokens: int, completion_tokens: int ) -> List[Consumption]: @@ -137,6 +128,22 @@ def get_cost_consumptions(self, *, prompt_tokens: int, completion_tokens: int) - ] +class DefaultLLMConsumptionCalculator(DefaultLLMConsumptionCalculatorHelper, abc.ABC): + def get_consumptions(self, duration: float, *, prompt_tokens: int, completion_tokens: int) -> List[Consumption]: + """ + Get default consumptions: + - 1 call + - specified duration + - prompt, completion and total tokens + - corresponding costs if LLMCostCard can be found. + """ + base_consumptions = self.get_base_consumptions( + duration, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) + cost_consumptions = self.get_cost_consumptions(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + return base_consumptions + cost_consumptions + + class LLMCostManagerSpec(DataObjectSpecBase): def __init__(self, costs: Dict[str, Dict[str, LLMCostCard]]) -> None: """ @@ -197,6 +204,11 @@ def openai(): """Get LLMCostManager for OpenAI models""" return LLMCostManagerObject.from_yaml(os.path.join(DATA_PATH, OPENAI_COSTS_FILENAME)) + @staticmethod + def groq(): + """Get LLMCostManager for Groq models""" + return LLMCostManagerObject.from_yaml(os.path.join(DATA_PATH, GROQ_COSTS_FILENAME)) + def get_cost_map(self, category: str) -> Dict[str, LLMCostCard]: """Get cost mapping {model: LLMCostCard} for a given category""" if category not in self.spec.costs: diff --git a/council/llm/ollama_llm.py b/council/llm/ollama_llm.py index 28b50f73..8ea9b637 100644 --- a/council/llm/ollama_llm.py +++ b/council/llm/ollama_llm.py @@ -32,14 +32,11 @@ def get_consumptions(self, duration: float, response: Mapping[str, Any]) -> List """ return ( - self.get_base_consumptions(duration) + self.get_default_consumptions(duration) + self.get_prompt_consumptions(response) + self.get_duration_consumptions(response) ) - def get_base_consumptions(self, duration: float) -> List[Consumption]: - return [Consumption.call(1, self.model), Consumption.duration(duration, self.model)] - def get_prompt_consumptions(self, response: Mapping[str, Any]) -> List[Consumption]: if not all(key in response for key in ["prompt_eval_count", "eval_count"]): return [] diff --git a/council/utils/__init__.py b/council/utils/__init__.py index 73b1c34b..63407dda 
100644 --- a/council/utils/__init__.py +++ b/council/utils/__init__.py @@ -12,7 +12,11 @@ ParameterValueException, Parameter, greater_than_validator, + positive_validator, + a_to_b_validator, zero_to_one_validator, + zero_to_two_validator, + penalty_validator, prefix_validator, prefix_any_validator, not_empty_validator, diff --git a/council/utils/parameter.py b/council/utils/parameter.py index b727b8fb..cc5dca59 100644 --- a/council/utils/parameter.py +++ b/council/utils/parameter.py @@ -17,9 +17,29 @@ def validator(x: int) -> None: return validator +def positive_validator(x: float) -> None: + if x <= 0.0: + raise ValueError("must be positive") + + +def a_to_b_validator(a: float, b: float) -> Validator: + def validator(x: float) -> None: + if x < a or x > b: + raise ValueError(f"must be in the range [{a}..{b}]") + + return validator + + def zero_to_one_validator(x: float) -> None: - if x < 0.0 or x > 1.0: - raise ValueError("must be in the range [0.0..1.0]") + return a_to_b_validator(0.0, 1.0)(x) + + +def zero_to_two_validator(x: float) -> None: + return a_to_b_validator(0.0, 2.0)(x) + + +def penalty_validator(x: float) -> None: + return a_to_b_validator(-2.0, 2.0)(x) def prefix_validator(value: str) -> Validator: diff --git a/docs/data/configs/llm-config-groq.yaml b/docs/data/configs/llm-config-groq.yaml new file mode 100644 index 00000000..61c4e4b0 --- /dev/null +++ b/docs/data/configs/llm-config-groq.yaml @@ -0,0 +1,19 @@ +kind: LLMConfig +version: 0.1 +metadata: + name: a-groq-deployed-model + labels: + provider: Groq +spec: + description: "Model used to do UVW" + provider: + name: Groq + groqSpec: + # https://console.groq.com/docs/models + model: llama-3.2-1b-preview + apiKey: + fromEnvVar: GROQ_API_KEY + parameters: + maxTokens: 128 + seed: 42 + temperature: 0 diff --git a/docs/source/reference/llm.md b/docs/source/reference/llm.md index aee16d57..26fe077c 100644 --- a/docs/source/reference/llm.md +++ b/docs/source/reference/llm.md @@ -23,6 +23,7 @@ Currently supported providers include: - Anthropic's Claude - {class}`~council.llm.AnthropicLLM` - Google's Gemini - {class}`~council.llm.GeminiLLM` - Microsoft's Azure - {class}`~council.llm.AzureLLM` +- Groq - {class}`~council.llm.GroqLLM` - and local models with [ollama](https://ollama.com/) - {class}`~council.llm.OllamaLLM` ```{eval-rst} diff --git a/docs/source/reference/llm/groq_llm.md b/docs/source/reference/llm/groq_llm.md new file mode 100644 index 00000000..94d63e45 --- /dev/null +++ b/docs/source/reference/llm/groq_llm.md @@ -0,0 +1,12 @@ +# GroqLLM + +```{eval-rst} +.. autoclasstree:: council.llm.GroqLLM + :full: + :namespace: council +``` + +```{eval-rst} +.. autoclass:: council.llm.GroqLLM + :member-order: bysource +``` diff --git a/docs/source/reference/llm/groq_llm_configuration.md b/docs/source/reference/llm/groq_llm_configuration.md new file mode 100644 index 00000000..17dfa39b --- /dev/null +++ b/docs/source/reference/llm/groq_llm_configuration.md @@ -0,0 +1,6 @@ +# GroqLLMConfiguration + +```{eval-rst} +.. autoclass:: council.llm.GroqLLMConfiguration + :member-order: bysource +``` diff --git a/docs/source/reference/llm/llm_config_object.md b/docs/source/reference/llm/llm_config_object.md index 740157e6..4fc4788a 100644 --- a/docs/source/reference/llm/llm_config_object.md +++ b/docs/source/reference/llm/llm_config_object.md @@ -52,6 +52,13 @@ Or use `council.llm.get_llm_from_config` to determine provider class automatical :language: yaml ``` +## Groq Config Example + +```{eval-rst} +.. 
literalinclude:: ../../../data/configs/llm-config-groq.yaml + :language: yaml +``` + ## Azure Config Example ```{eval-rst} diff --git a/requirements.txt b/requirements.txt index cab877d8..0d0e0568 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ PyYAML~=6.0.1 anthropic~=0.34.2 google-generativeai>=0.7.0 ollama~=0.3.3 +groq~=0.12.0 # Skills ## Google diff --git a/tests/data/groq-llmodel.yaml b/tests/data/groq-llmodel.yaml new file mode 100644 index 00000000..ea712c0c --- /dev/null +++ b/tests/data/groq-llmodel.yaml @@ -0,0 +1,21 @@ +kind: LLMConfig +version: 0.1 +metadata: + name: an-groq-deployed-model + labels: + provider: Groq +spec: + description: "Model used to do DEF" + provider: + name: CML-Groq + groqSpec: + model: llama-3.2-1b-preview + apiKey: + fromEnvVar: GROQ_API_KEY + parameters: + frequencyPenalty: 0.7 + maxTokens: 24 + presencePenalty: -0.4 + seed: 42 + temperature: 0.5 + topP: 0.1 diff --git a/tests/integration/llm/test_groq_llm.py b/tests/integration/llm/test_groq_llm.py new file mode 100644 index 00000000..a61d01ee --- /dev/null +++ b/tests/integration/llm/test_groq_llm.py @@ -0,0 +1,54 @@ +import unittest +import dotenv + +import json + +from council import LLMContext +from council.llm import LLMMessage, GroqLLM +from council.utils import OsEnviron + + +class TestGroqLLM(unittest.TestCase): + @staticmethod + def get_gemma_7b(): + dotenv.load_dotenv() + with OsEnviron("GROQ_LLM_MODEL", "gemma-7b-it"): + return GroqLLM.from_env() + + def test_message(self): + messages = [LLMMessage.user_message("what is the capital of France?")] + llm = self.get_gemma_7b() + + result = llm.post_chat_request(LLMContext.empty(), messages) + assert "Paris" in result.choices[0] + + messages.append(LLMMessage.user_message("give a famous monument of that place")) + result = llm.post_chat_request(LLMContext.empty(), messages) + + assert "Eiffel" in result.choices[0] + + def test_consumptions(self): + messages = [LLMMessage.user_message("Hello how are you?")] + llm = self.get_gemma_7b() + result = llm.post_chat_request(LLMContext.empty(), messages) + + assert len(result.consumptions) == 12 # call, duration, 3 token kinds, 3 cost kinds and 4 groq duration + for consumption in result.consumptions: + assert consumption.kind.startswith("gemma-7b-it") + + def test_max_tokens_param(self): + llm = self.get_gemma_7b() + llm.configuration.temperature.set(0.8) + llm.configuration.max_tokens.set(7) + + messages = [LLMMessage.user_message("Hey how are you?")] + result = llm.post_chat_request(LLMContext.empty(), messages) + print(f"Predicted: {result.first_choice}") + + def test_json_mode(self): + messages = [LLMMessage.user_message("Output a JSON object with the data about RPG character.")] + llm = self.get_gemma_7b() + result = llm.post_chat_request(LLMContext.empty(), messages, response_format={"type": "json_object"}) + + data = json.loads(result.first_choice) + assert isinstance(data, dict) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index b7875856..d974ab7f 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -5,6 +5,7 @@ class LLModels: AzureWithFallback: str = "azure-with-fallback-llmodel.yaml" Gemini: str = "gemini-llmodel.yaml" Ollama: str = "ollama-llmodel.yaml" + Groq: str = "groq-llmodel.yaml" class LLMPrompts: diff --git a/tests/unit/llm/test_groq_llm_configuration.py b/tests/unit/llm/test_groq_llm_configuration.py new file mode 100644 index 00000000..759fc958 --- /dev/null +++ b/tests/unit/llm/test_groq_llm_configuration.py @@ -0,0 +1,20 @@ 
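--- /dev/null
+++ b/tests/unit/llm/test_groq_llm_configuration.py
@@ -0,0 +1,20 @@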
+import unittest +from council.llm import GroqLLMConfiguration +from council.utils import OsEnviron, ParameterValueException + + +class TestGroqLLMConfiguration(unittest.TestCase): + def test_model_override(self): + with OsEnviron("GROQ_API_KEY", "some-key"), OsEnviron("GROQ_LLM_MODEL", "llama-something"): + config = GroqLLMConfiguration.from_env() + self.assertEqual("some-key", config.api_key.value) + self.assertEqual("llama-something", config.model.value) + + def test_default(self): + config = GroqLLMConfiguration(model="llama-something", api_key="some-key") + self.assertEqual(0.0, config.temperature.value) + self.assertTrue(config.top_p.is_none()) + + def test_invalid(self): + with self.assertRaises(ParameterValueException): + _ = GroqLLMConfiguration(model="llama-something", api_key="") diff --git a/tests/unit/llm/test_llm_config_object.py b/tests/unit/llm/test_llm_config_object.py index e500d9d5..e8180843 100644 --- a/tests/unit/llm/test_llm_config_object.py +++ b/tests/unit/llm/test_llm_config_object.py @@ -5,6 +5,8 @@ OpenAIChatGPTConfiguration, OllamaLLM, OllamaLLMConfiguration, + GroqLLM, + GroqLLMConfiguration, ) from council.llm.llm_config_object import LLMConfigObject from council.utils import OsEnviron @@ -108,3 +110,23 @@ def test_ollama_from_yaml(): llm = get_llm_from_config(filename) assert isinstance(llm, OllamaLLM) + + +def test_groq_from_yaml(): + filename = get_data_filename(LLModels.Groq) + + with OsEnviron("GROQ_API_KEY", "a-key"): + actual = LLMConfigObject.from_yaml(filename) + llm = GroqLLM.from_config(actual) + assert isinstance(llm, GroqLLM) + config: GroqLLMConfiguration = llm.configuration + assert config.model.value == "llama-3.2-1b-preview" + assert config.frequency_penalty.value == 0.7 + assert config.max_tokens.value == 24 + assert config.presence_penalty.value == -0.4 + assert config.seed.value == 42 + assert config.temperature.value == 0.5 + assert config.top_p.value == 0.1 + + llm = get_llm_from_config(filename) + assert isinstance(llm, GroqLLM) diff --git a/tests/unit/utils/test_parameter.py b/tests/unit/utils/test_parameter.py index 80c2079c..294c7476 100644 --- a/tests/unit/utils/test_parameter.py +++ b/tests/unit/utils/test_parameter.py @@ -1,15 +1,7 @@ import unittest from council.utils import MissingEnvVariableException, EnvVariableValueException, OsEnviron -from council.utils.parameter import ParameterValueException, Parameter, Undefined - - -def tv(x: float): - """ - Temperature is an Optional float valid between 0.0 and 2.0, default value 0.0 - """ - if x < 0.0 or x > 2.0: - raise ValueError("must be in the range [0.0..2.0]") +from council.utils.parameter import ParameterValueException, Parameter, Undefined, zero_to_two_validator as tv class TestParameter(unittest.TestCase):
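Reviewer note (not part of the patch): a minimal usage sketch for the new Groq provider, based on the integration tests above. It assumes GROQ_API_KEY and GROQ_LLM_MODEL are set in the environment; the model referenced is whatever GROQ_LLM_MODEL points to (e.g. one of the models listed in groq-costs.yaml).

    from council import LLMContext
    from council.llm import GroqLLM, LLMMessage

    # GroqLLM.from_env() builds a GroqLLMConfiguration from the GROQ_API_KEY
    # and GROQ_LLM_MODEL environment variables.
    llm = GroqLLM.from_env()

    messages = [LLMMessage.user_message("What is the capital of France?")]
    result = llm.post_chat_request(LLMContext.empty(), messages)

    print(result.first_choice)

    # Consumptions include the usual call/duration/token/cost entries plus the
    # Groq-specific queue/prompt/completion/total time durations added in this patch.
    for consumption in result.consumptions:
        print(consumption.kind, consumption.value)

Alternatively, a model can be configured from YAML (see docs/data/configs/llm-config-groq.yaml) and loaded with council.llm.get_llm_from_config.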