From e38e2ea796ae3b39a811ff047e4e370db8554eec Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Fri, 15 Nov 2024 16:11:17 -0500 Subject: [PATCH 01/14] Initial implementation --- council/llm/__init__.py | 7 ++ council/llm/data/groq-costs.yaml | 15 +++ council/llm/groq_llm.py | 132 +++++++++++++++++++++++++ council/llm/groq_llm_configuration.py | 73 ++++++++++++++ council/llm/llm_config_object.py | 6 ++ council/llm/llm_cost.py | 6 ++ requirements.txt | 1 + tests/data/groq-llmodel.yaml | 16 +++ tests/integration/llm/test_groq_llm.py | 34 +++++++ 9 files changed, 290 insertions(+) create mode 100644 council/llm/data/groq-costs.yaml create mode 100644 council/llm/groq_llm.py create mode 100644 council/llm/groq_llm_configuration.py create mode 100644 tests/data/groq-llmodel.yaml create mode 100644 tests/integration/llm/test_groq_llm.py diff --git a/council/llm/__init__.py b/council/llm/__init__.py index 51024ba2..8c893684 100644 --- a/council/llm/__init__.py +++ b/council/llm/__init__.py @@ -61,6 +61,9 @@ from .ollama_llm_configuration import OllamaLLMConfiguration from .ollama_llm import OllamaLLM +from .groq_llm_configuration import GroqLLMConfiguration +from .groq_llm import GroqLLM + def get_default_llm(max_retries: Optional[int] = None) -> LLMBase: provider = read_env_str("COUNCIL_DEFAULT_LLM_PROVIDER", default=LLMProviders.OpenAI).unwrap() @@ -77,6 +80,8 @@ def get_default_llm(max_retries: Optional[int] = None) -> LLMBase: llm = GeminiLLM.from_env() elif provider == LLMProviders.Ollama.lower(): llm = OllamaLLM.from_env() + elif provider == LLMProviders.Groq.lower(): + llm = GroqLLM.from_env() if llm is None: raise ValueError(f"Provider {provider} not supported by Council.") @@ -111,5 +116,7 @@ def _build_llm(llm_config: LLMConfigObject) -> LLMBase: return GeminiLLM.from_config(llm_config) elif provider.is_of_kind(LLMProviders.Ollama): return OllamaLLM.from_config(llm_config) + elif provider.is_of_kind(LLMProviders.Groq): + return GroqLLM.from_config(llm_config) raise ValueError(f"Provider `{provider.kind}` not supported by Council") diff --git a/council/llm/data/groq-costs.yaml b/council/llm/data/groq-costs.yaml new file mode 100644 index 00000000..942fadd7 --- /dev/null +++ b/council/llm/data/groq-costs.yaml @@ -0,0 +1,15 @@ +kind: LLMCostManager +version: 0.1 +metadata: + name: groq-costs + labels: + provider: Groq + reference: https://groq.com/pricing/ +spec: + default: + description: | + Costs for Groq LLMs + models: + llama-3.2-1b-preview: + input: 0.04 + output: 0.04 diff --git a/council/llm/groq_llm.py b/council/llm/groq_llm.py new file mode 100644 index 00000000..06ac0358 --- /dev/null +++ b/council/llm/groq_llm.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Any, List, Mapping, Optional, Sequence + +from council.contexts import Consumption, LLMContext +from council.llm import ( + DefaultLLMConsumptionCalculator, + GroqLLMConfiguration, + LLMBase, + LLMConfigObject, + LLMCostCard, + LLMCostManagerObject, + LLMMessage, + LLMMessageRole, + LLMProviders, + LLMResult, +) +from council.utils.utils import DurationManager +from groq import Groq +from groq.types import CompletionUsage +from groq.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, +) +from groq.types.chat.chat_completion import ChatCompletion, Choice + + +class GroqConsumptionCalculator(DefaultLLMConsumptionCalculator): + _cost_manager = LLMCostManagerObject.groq() + COSTS: 
Mapping[str, LLMCostCard] = _cost_manager.get_cost_map("default") + + def __init__(self, model: str) -> None: + super().__init__(model) + + # TODO: naming + def get_groq_consumptions(self, duration: float, usage: Optional[CompletionUsage]) -> List[Consumption]: + if usage is None: + return [Consumption.call(1, self.model), Consumption.duration(duration, self.model)] + + prompt_tokens = usage.prompt_tokens + completion_tokens = usage.completion_tokens + return ( + self.get_base_consumptions(duration, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + + self.get_duration_consumptions(usage) + + self.get_cost_consumptions(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + ) + + def get_duration_consumptions(self, usage: CompletionUsage) -> List[Consumption]: + """Optional duration consumptions specific to Groq.""" + consumptions = [] + if usage.queue_time is not None: + consumptions.append(Consumption.duration(usage.queue_time, f"{self.model}:groq_queue_time")) + if usage.prompt_time is not None: + consumptions.append(Consumption.duration(usage.prompt_time, f"{self.model}:groq_prompt_time")) + if usage.completion_time is not None: + consumptions.append(Consumption.duration(usage.completion_time, f"{self.model}:groq_completion_time")) + if usage.total_time is not None: + consumptions.append(Consumption.duration(usage.total_time, f"{self.model}:groq_total_time")) + return consumptions + + def find_model_costs(self) -> Optional[LLMCostCard]: + return self.COSTS.get(self.model) + + +class GroqLLM(LLMBase[GroqLLMConfiguration]): + def __init__(self, config: GroqLLMConfiguration) -> None: + """ + Initialize a new instance. + + Args: + config(GroqLLMConfiguration): configuration for the instance + """ + super().__init__(name=f"{self.__class__.__name__}", configuration=config) + self._client = Groq(api_key=config.api_key.value) + + def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult: + formatted_messages = self._build_messages_payload(messages) + + with DurationManager() as timer: + response = self._client.chat.completions.create( + messages=formatted_messages, + model=self._configuration.model_name(), + temperature=self._configuration.temperature.value, + ) + + return LLMResult( + choices=self._to_choices(response.choices), consumptions=self._to_consumptions(timer.duration, response) + ) + + @staticmethod + def _build_messages_payload(messages: Sequence[LLMMessage]) -> List[ChatCompletionMessageParam]: + def _llm_message_to_groq_message(message: LLMMessage) -> ChatCompletionMessageParam: + if message.is_of_role(LLMMessageRole.System): + return ChatCompletionSystemMessageParam(role="system", content=message.content) + elif message.is_of_role(LLMMessageRole.User): + return ChatCompletionUserMessageParam(role="user", content=message.content) + elif message.is_of_role(LLMMessageRole.Assistant): + return ChatCompletionAssistantMessageParam(role="assistant", content=message.content) + + raise ValueError(f"Unknown LLMessage role: `{message.role.value}`") + + return [_llm_message_to_groq_message(message) for message in messages] + + @staticmethod + def _to_choices(choices: List[Choice]) -> List[str]: + return [choice.message.content if choice.message.content is not None else "" for choice in choices] + + @staticmethod + def _to_consumptions(duration: float, response: ChatCompletion) -> Sequence[Consumption]: + calculator = GroqConsumptionCalculator(response.model) + return calculator.get_groq_consumptions(duration, 
response.usage) + + @staticmethod + def from_env() -> GroqLLM: + """ + Helper function that create a new instance by getting the configuration from environment variables. + + Returns: + GroqLLM + """ + return GroqLLM(GroqLLMConfiguration.from_env()) + + @staticmethod + def from_config(config_object: LLMConfigObject) -> GroqLLM: + provider = config_object.spec.provider + if not provider.is_of_kind(LLMProviders.Groq): + raise ValueError(f"Invalid LLM provider, actual {provider}, expected {LLMProviders.Groq}") + + config = GroqLLMConfiguration.from_spec(config_object.spec) + return GroqLLM(config=config) diff --git a/council/llm/groq_llm_configuration.py b/council/llm/groq_llm_configuration.py new file mode 100644 index 00000000..63893e6c --- /dev/null +++ b/council/llm/groq_llm_configuration.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, Final, Optional + +from council.utils import Parameter, not_empty_validator, read_env_str, zero_to_one_validator + +from . import LLMConfigSpec, LLMConfigurationBase + +_env_var_prefix: Final[str] = "GROQ_" + + +class GroqLLMConfiguration(LLMConfigurationBase): + def __init__(self, model: str, api_key: str) -> None: + """ + Initialize a new instance + + Args: + api_key (str): the api key + model (str): the model name + """ + super().__init__() + self._model = Parameter.string(name="model", required=True, value=model, validator=not_empty_validator) + self._api_key = Parameter.string(name="api_key", required=True, value=api_key, validator=not_empty_validator) + self._temperature = Parameter.float( + name="temperature", required=False, default=0.0, validator=zero_to_one_validator + ) + + def model_name(self) -> str: + return self._model.unwrap() + + @property + def model(self) -> Parameter[str]: + """ + Groq model name + """ + return self._model + + @property + def api_key(self) -> Parameter[str]: + """ + Groq API Key + """ + return self._api_key + + @property + def temperature(self) -> Parameter[float]: + """ + Amount of randomness injected into the response. + Ranges from 0 to 1. 
+ """ + return self._temperature + + # TODO: more parameters + + @staticmethod + def from_env() -> GroqLLMConfiguration: + api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap() + model = read_env_str(_env_var_prefix + "LLM_MODEL").unwrap() + config = GroqLLMConfiguration(model=model, api_key=api_key) + return config + + @staticmethod + def from_spec(spec: LLMConfigSpec) -> GroqLLMConfiguration: + api_key = spec.provider.must_get_value("apiKey") + model = spec.provider.must_get_value("model") + config = GroqLLMConfiguration(model=str(model), api_key=str(api_key)) + + if spec.parameters is not None: + value: Optional[Any] = spec.parameters.get("temperature", None) + if value is not None: + config.temperature.set(float(value)) + + return config diff --git a/council/llm/llm_config_object.py b/council/llm/llm_config_object.py index 8b2148f5..64b5eb85 100644 --- a/council/llm/llm_config_object.py +++ b/council/llm/llm_config_object.py @@ -15,6 +15,7 @@ class LLMProviders(str, Enum): Anthropic = "anthropicSpec" Gemini = "googleGeminiSpec" Ollama = "ollamaSpec" + Groq = "groqSpec" class LLMProvider: @@ -51,6 +52,9 @@ def from_dict(cls, values: Dict[str, Any]) -> LLMProvider: spec = values.get(LLMProviders.Ollama) if spec is not None: return LLMProvider(name, description, spec, LLMProviders.Ollama) + spec = values.get(LLMProviders.Groq) + if spec is not None: + return LLMProvider(name, description, spec, LLMProviders.Groq) raise ValueError("Unsupported model provider") def to_dict(self) -> Dict[str, Any]: @@ -65,6 +69,8 @@ def to_dict(self) -> Dict[str, Any]: result[LLMProviders.Gemini] = self._specs elif self.is_of_kind(LLMProviders.Ollama): result[LLMProviders.Ollama] = self._specs + elif self.is_of_kind(LLMProviders.Groq): + result[LLMProviders.Groq] = self._specs return result def must_get_value(self, key: str) -> Any: diff --git a/council/llm/llm_cost.py b/council/llm/llm_cost.py index 54feda6c..f54c381f 100644 --- a/council/llm/llm_cost.py +++ b/council/llm/llm_cost.py @@ -13,6 +13,7 @@ ANTHROPIC_COSTS_FILENAME: Final[str] = "anthropic-costs.yaml" GEMINI_COSTS_FILENAME: Final[str] = "gemini-costs.yaml" OPENAI_COSTS_FILENAME: Final[str] = "openai-costs.yaml" +GROQ_COSTS_FILENAME: Final[str] = "groq-costs.yaml" class LLMCostCard: @@ -197,6 +198,11 @@ def openai(): """Get LLMCostManager for OpenAI models""" return LLMCostManagerObject.from_yaml(os.path.join(DATA_PATH, OPENAI_COSTS_FILENAME)) + @staticmethod + def groq(): + """Get LLMCostManager for Groq models""" + return LLMCostManagerObject.from_yaml(os.path.join(DATA_PATH, GROQ_COSTS_FILENAME)) + def get_cost_map(self, category: str) -> Dict[str, LLMCostCard]: """Get cost mapping {model: LLMCostCard} for a given category""" if category not in self.spec.costs: diff --git a/requirements.txt b/requirements.txt index cab877d8..0d0e0568 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ PyYAML~=6.0.1 anthropic~=0.34.2 google-generativeai>=0.7.0 ollama~=0.3.3 +groq~=0.12.0 # Skills ## Google diff --git a/tests/data/groq-llmodel.yaml b/tests/data/groq-llmodel.yaml new file mode 100644 index 00000000..8477bca0 --- /dev/null +++ b/tests/data/groq-llmodel.yaml @@ -0,0 +1,16 @@ +kind: LLMConfig +version: 0.1 +metadata: + name: an-groq-deployed-model + labels: + provider: Groq +spec: + description: "Model used to do DEF" + provider: + name: CML-Groq + groqSpec: + model: llama-3.2-1b-preview + apiKey: + fromEnvVar: GROQ_API_KEY + parameters: + temperature: 0.5 diff --git a/tests/integration/llm/test_groq_llm.py 
b/tests/integration/llm/test_groq_llm.py new file mode 100644 index 00000000..38088a2c --- /dev/null +++ b/tests/integration/llm/test_groq_llm.py @@ -0,0 +1,34 @@ +import unittest +import dotenv + +from council import LLMContext +from council.llm import LLMMessage, GroqLLM +from council.utils import OsEnviron + + +class TestGroqLLM(unittest.TestCase): + def test_message(self): + messages = [LLMMessage.user_message("what is the capital of France?")] + dotenv.load_dotenv() + with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"): + instance = GroqLLM.from_env() + context = LLMContext.empty() + result = instance.post_chat_request(context, messages) + + assert "Paris" in result.choices[0] + + messages.append(LLMMessage.user_message("give a famous monument of that place")) + result = instance.post_chat_request(context, messages) + + assert "Eiffel" in result.choices[0] + + def test_consumptions(self): + messages = [LLMMessage.user_message("Hello how are you?")] + dotenv.load_dotenv() + with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"): + instance = GroqLLM.from_env() + result = instance.post_chat_request(LLMContext.empty(), messages) + + assert len(result.consumptions) == 12 # call, duration, 3 token kinds, 3 cost kinds and 4 groq duration + for consumption in result.consumptions: + assert consumption.kind.startswith("llama-3.2-1b-preview") From 5e680eceb5acf321c84e02dae1bbe5622bfc6207 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 14:50:46 -0500 Subject: [PATCH 02/14] Update --- tests/integration/llm/test_groq_llm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/integration/llm/test_groq_llm.py b/tests/integration/llm/test_groq_llm.py index 38088a2c..a88efe42 100644 --- a/tests/integration/llm/test_groq_llm.py +++ b/tests/integration/llm/test_groq_llm.py @@ -12,13 +12,12 @@ def test_message(self): dotenv.load_dotenv() with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"): instance = GroqLLM.from_env() - context = LLMContext.empty() - result = instance.post_chat_request(context, messages) - assert "Paris" in result.choices[0] + result = instance.post_chat_request(LLMContext.empty(), messages) + assert "Paris" in result.choices[0] messages.append(LLMMessage.user_message("give a famous monument of that place")) - result = instance.post_chat_request(context, messages) + result = instance.post_chat_request(LLMContext.empty(), messages) assert "Eiffel" in result.choices[0] From d9bb0299189808806976586c41c99c1d887a9537 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 15:23:15 -0500 Subject: [PATCH 03/14] Add raw response --- council/llm/groq_llm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/council/llm/groq_llm.py b/council/llm/groq_llm.py index 06ac0358..cc529337 100644 --- a/council/llm/groq_llm.py +++ b/council/llm/groq_llm.py @@ -86,7 +86,9 @@ def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage] ) return LLMResult( - choices=self._to_choices(response.choices), consumptions=self._to_consumptions(timer.duration, response) + choices=self._to_choices(response.choices), + consumptions=self._to_consumptions(timer.duration, response), + raw_response=response.to_dict(), ) @staticmethod From 229de8653ceca172ac3c8fb7ff1b27c423846f4d Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 15:24:38 -0500 Subject: [PATCH 04/14] Implement get_default_consumptions() --- council/llm/groq_llm.py | 2 +- council/llm/llm_cost.py | 4 ++++ 
council/llm/ollama_llm.py | 5 +---- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/council/llm/groq_llm.py b/council/llm/groq_llm.py index cc529337..8c116e07 100644 --- a/council/llm/groq_llm.py +++ b/council/llm/groq_llm.py @@ -37,7 +37,7 @@ def __init__(self, model: str) -> None: # TODO: naming def get_groq_consumptions(self, duration: float, usage: Optional[CompletionUsage]) -> List[Consumption]: if usage is None: - return [Consumption.call(1, self.model), Consumption.duration(duration, self.model)] + return self.get_default_consumptions(duration) prompt_tokens = usage.prompt_tokens completion_tokens = usage.completion_tokens diff --git a/council/llm/llm_cost.py b/council/llm/llm_cost.py index f54c381f..51322ef6 100644 --- a/council/llm/llm_cost.py +++ b/council/llm/llm_cost.py @@ -89,6 +89,10 @@ def get_consumptions(self, *args, **kwargs) -> List[Consumption]: """Each calculator will implement with its own parameters.""" pass + def get_default_consumptions(self, duration: float) -> List[Consumption]: + """1 call and specified duration consumptions. To use when token info is not available""" + return [Consumption.call(1, self.model), Consumption.duration(duration, self.model)] + @abc.abstractmethod def find_model_costs(self) -> Optional[LLMCostCard]: """Get LLMCostCard for self to calculate cost consumptions.""" diff --git a/council/llm/ollama_llm.py b/council/llm/ollama_llm.py index 28b50f73..8ea9b637 100644 --- a/council/llm/ollama_llm.py +++ b/council/llm/ollama_llm.py @@ -32,14 +32,11 @@ def get_consumptions(self, duration: float, response: Mapping[str, Any]) -> List """ return ( - self.get_base_consumptions(duration) + self.get_default_consumptions(duration) + self.get_prompt_consumptions(response) + self.get_duration_consumptions(response) ) - def get_base_consumptions(self, duration: float) -> List[Consumption]: - return [Consumption.call(1, self.model), Consumption.duration(duration, self.model)] - def get_prompt_consumptions(self, response: Mapping[str, Any]) -> List[Consumption]: if not all(key in response for key in ["prompt_eval_count", "eval_count"]): return [] From efdf975be4cb952be9afd23223d689e7f727c348 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 15:25:34 -0500 Subject: [PATCH 05/14] Update docs --- docs/data/configs/llm-config-groq.yaml | 18 ++++++++++++++++++ docs/source/reference/llm.md | 1 + docs/source/reference/llm/llm_config_object.md | 7 +++++++ 3 files changed, 26 insertions(+) create mode 100644 docs/data/configs/llm-config-groq.yaml diff --git a/docs/data/configs/llm-config-groq.yaml b/docs/data/configs/llm-config-groq.yaml new file mode 100644 index 00000000..ac5d4076 --- /dev/null +++ b/docs/data/configs/llm-config-groq.yaml @@ -0,0 +1,18 @@ +kind: LLMConfig +version: 0.1 +metadata: + name: a-groq-deployed-model + labels: + provider: Groq +spec: + description: "Model used to do UVW" + provider: + name: Groq + groqSpec: + # https://console.groq.com/docs/models + model: llama-3.2-1b-preview + apiKey: + fromEnvVar: GROQ_API_KEY + parameters: + n: 1 + temperature: 0 diff --git a/docs/source/reference/llm.md b/docs/source/reference/llm.md index aee16d57..26fe077c 100644 --- a/docs/source/reference/llm.md +++ b/docs/source/reference/llm.md @@ -23,6 +23,7 @@ Currently supported providers include: - Anthropic's Claude - {class}`~council.llm.AnthropicLLM` - Google's Gemini - {class}`~council.llm.GeminiLLM` - Microsoft's Azure - {class}`~council.llm.AzureLLM` +- Groq - {class}`~council.llm.GroqLLM` - and local models with 
[ollama](https://ollama.com/) - {class}`~council.llm.OllamaLLM` ```{eval-rst} diff --git a/docs/source/reference/llm/llm_config_object.md b/docs/source/reference/llm/llm_config_object.md index 740157e6..4fc4788a 100644 --- a/docs/source/reference/llm/llm_config_object.md +++ b/docs/source/reference/llm/llm_config_object.md @@ -52,6 +52,13 @@ Or use `council.llm.get_llm_from_config` to determine provider class automatical :language: yaml ``` +## Groq Config Example + +```{eval-rst} +.. literalinclude:: ../../../data/configs/llm-config-groq.yaml + :language: yaml +``` + ## Azure Config Example ```{eval-rst} From 2b4ce58a7ec4a03fc29f986e411292efac1cdd6f Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 16:01:29 -0500 Subject: [PATCH 06/14] Minor validators refactoring --- council/llm/chat_gpt_configuration.py | 43 ++++++--------------------- council/utils/__init__.py | 4 +++ council/utils/parameter.py | 24 +++++++++++++-- tests/unit/utils/test_parameter.py | 10 +------ 4 files changed, 36 insertions(+), 45 deletions(-) diff --git a/council/llm/chat_gpt_configuration.py b/council/llm/chat_gpt_configuration.py index df2858b6..fc3fef68 100644 --- a/council/llm/chat_gpt_configuration.py +++ b/council/llm/chat_gpt_configuration.py @@ -2,34 +2,7 @@ from typing import Any, Dict, Optional from council.llm.llm_base import LLMConfigurationBase -from council.utils import Parameter - - -def _tv(x: float): - """ - Temperature Validator - Sampling temperature to use, between 0. and 2. - """ - if x < 0.0 or x > 2.0: - raise ValueError("must be in the range [0.0..2.0]") - - -def _pv(x: float): - """ - Penalty Validator - Penalty must be between -2.0 and 2.0 - """ - if x < -2.0 or x > 2.0: - raise ValueError("must be in the range [-2.0..2.0]") - - -def _mtv(x: int): - """ - Max Token Validator - Must be positive - """ - if x <= 0: - raise ValueError("must be positive") +from council.utils import Parameter, penalty_validator, positive_validator, zero_to_one_validator, zero_to_two_validator class ChatGPTConfigurationBase(LLMConfigurationBase, ABC): @@ -38,12 +11,14 @@ class ChatGPTConfigurationBase(LLMConfigurationBase, ABC): """ def __init__(self) -> None: - self._temperature = Parameter.float(name="temperature", required=False, default=0.0, validator=_tv) - self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=_mtv) - self._top_p = Parameter.float(name="top_p", required=False) - self._n = Parameter.int(name="n", required=False, default=1) - self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=_pv) - self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=_pv) + self._temperature = Parameter.float( + name="temperature", required=False, default=0.0, validator=zero_to_two_validator + ) + self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=positive_validator) + self._top_p = Parameter.float(name="top_p", required=False, validator=zero_to_one_validator) + self._n = Parameter.int(name="n", required=False, default=1, validator=positive_validator) + self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=penalty_validator) + self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=penalty_validator) @property def temperature(self) -> Parameter[float]: diff --git a/council/utils/__init__.py b/council/utils/__init__.py index 73b1c34b..63407dda 100644 --- a/council/utils/__init__.py +++ 
b/council/utils/__init__.py @@ -12,7 +12,11 @@ ParameterValueException, Parameter, greater_than_validator, + positive_validator, + a_to_b_validator, zero_to_one_validator, + zero_to_two_validator, + penalty_validator, prefix_validator, prefix_any_validator, not_empty_validator, diff --git a/council/utils/parameter.py b/council/utils/parameter.py index b727b8fb..cc5dca59 100644 --- a/council/utils/parameter.py +++ b/council/utils/parameter.py @@ -17,9 +17,29 @@ def validator(x: int) -> None: return validator +def positive_validator(x: float) -> None: + if x <= 0.0: + raise ValueError("must be positive") + + +def a_to_b_validator(a: float, b: float) -> Validator: + def validator(x: float) -> None: + if x < a or x > b: + raise ValueError(f"must be in the range [{a}..{b}]") + + return validator + + def zero_to_one_validator(x: float) -> None: - if x < 0.0 or x > 1.0: - raise ValueError("must be in the range [0.0..1.0]") + return a_to_b_validator(0.0, 1.0)(x) + + +def zero_to_two_validator(x: float) -> None: + return a_to_b_validator(0.0, 2.0)(x) + + +def penalty_validator(x: float) -> None: + return a_to_b_validator(-2.0, 2.0)(x) def prefix_validator(value: str) -> Validator: diff --git a/tests/unit/utils/test_parameter.py b/tests/unit/utils/test_parameter.py index 80c2079c..294c7476 100644 --- a/tests/unit/utils/test_parameter.py +++ b/tests/unit/utils/test_parameter.py @@ -1,15 +1,7 @@ import unittest from council.utils import MissingEnvVariableException, EnvVariableValueException, OsEnviron -from council.utils.parameter import ParameterValueException, Parameter, Undefined - - -def tv(x: float): - """ - Temperature is an Optional float valid between 0.0 and 2.0, default value 0.0 - """ - if x < 0.0 or x > 2.0: - raise ValueError("must be in the range [0.0..2.0]") +from council.utils.parameter import ParameterValueException, Parameter, Undefined, zero_to_two_validator as tv class TestParameter(unittest.TestCase): From 38d0be7eba23f4ba13b10f91bb2f8681a5712d3b Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 16:03:14 -0500 Subject: [PATCH 07/14] Add config unit tests --- tests/data/groq-llmodel.yaml | 5 +++++ tests/unit/__init__.py | 1 + tests/unit/llm/test_groq_llm_configuration.py | 20 +++++++++++++++++ tests/unit/llm/test_llm_config_object.py | 22 +++++++++++++++++++ 4 files changed, 48 insertions(+) create mode 100644 tests/unit/llm/test_groq_llm_configuration.py diff --git a/tests/data/groq-llmodel.yaml b/tests/data/groq-llmodel.yaml index 8477bca0..ea712c0c 100644 --- a/tests/data/groq-llmodel.yaml +++ b/tests/data/groq-llmodel.yaml @@ -13,4 +13,9 @@ spec: apiKey: fromEnvVar: GROQ_API_KEY parameters: + frequencyPenalty: 0.7 + maxTokens: 24 + presencePenalty: -0.4 + seed: 42 temperature: 0.5 + topP: 0.1 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index b7875856..d974ab7f 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -5,6 +5,7 @@ class LLModels: AzureWithFallback: str = "azure-with-fallback-llmodel.yaml" Gemini: str = "gemini-llmodel.yaml" Ollama: str = "ollama-llmodel.yaml" + Groq: str = "groq-llmodel.yaml" class LLMPrompts: diff --git a/tests/unit/llm/test_groq_llm_configuration.py b/tests/unit/llm/test_groq_llm_configuration.py new file mode 100644 index 00000000..759fc958 --- /dev/null +++ b/tests/unit/llm/test_groq_llm_configuration.py @@ -0,0 +1,20 @@ +import unittest +from council.llm import GroqLLMConfiguration +from council.utils import OsEnviron, ParameterValueException + + +class 
TestGroqLLMConfiguration(unittest.TestCase): + def test_model_override(self): + with OsEnviron("GROQ_API_KEY", "some-key"), OsEnviron("GROQ_LLM_MODEL", "llama-something"): + config = GroqLLMConfiguration.from_env() + self.assertEqual("some-key", config.api_key.value) + self.assertEqual("llama-something", config.model.value) + + def test_default(self): + config = GroqLLMConfiguration(model="llama-something", api_key="some-key") + self.assertEqual(0.0, config.temperature.value) + self.assertTrue(config.top_p.is_none()) + + def test_invalid(self): + with self.assertRaises(ParameterValueException): + _ = GroqLLMConfiguration(model="llama-something", api_key="") diff --git a/tests/unit/llm/test_llm_config_object.py b/tests/unit/llm/test_llm_config_object.py index e500d9d5..e8180843 100644 --- a/tests/unit/llm/test_llm_config_object.py +++ b/tests/unit/llm/test_llm_config_object.py @@ -5,6 +5,8 @@ OpenAIChatGPTConfiguration, OllamaLLM, OllamaLLMConfiguration, + GroqLLM, + GroqLLMConfiguration, ) from council.llm.llm_config_object import LLMConfigObject from council.utils import OsEnviron @@ -108,3 +110,23 @@ def test_ollama_from_yaml(): llm = get_llm_from_config(filename) assert isinstance(llm, OllamaLLM) + + +def test_groq_from_yaml(): + filename = get_data_filename(LLModels.Groq) + + with OsEnviron("GROQ_API_KEY", "a-key"): + actual = LLMConfigObject.from_yaml(filename) + llm = GroqLLM.from_config(actual) + assert isinstance(llm, GroqLLM) + config: GroqLLMConfiguration = llm.configuration + assert config.model.value == "llama-3.2-1b-preview" + assert config.frequency_penalty.value == 0.7 + assert config.max_tokens.value == 24 + assert config.presence_penalty.value == -0.4 + assert config.seed.value == 42 + assert config.temperature.value == 0.5 + assert config.top_p.value == 0.1 + + llm = get_llm_from_config(filename) + assert isinstance(llm, GroqLLM) From ef5c83135030c48bbf9d4ad3976017f915c46ee6 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 16:03:42 -0500 Subject: [PATCH 08/14] Fix n in example yaml --- docs/data/configs/llm-config-groq.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/data/configs/llm-config-groq.yaml b/docs/data/configs/llm-config-groq.yaml index ac5d4076..c14c19a8 100644 --- a/docs/data/configs/llm-config-groq.yaml +++ b/docs/data/configs/llm-config-groq.yaml @@ -14,5 +14,4 @@ spec: apiKey: fromEnvVar: GROQ_API_KEY parameters: - n: 1 temperature: 0 From a96be04b57b33846d74090f0d065a64f57fc994b Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 16:03:57 -0500 Subject: [PATCH 09/14] Add more parameters --- council/llm/groq_llm_configuration.py | 87 ++++++++++++++++++++++++--- 1 file changed, 77 insertions(+), 10 deletions(-) diff --git a/council/llm/groq_llm_configuration.py b/council/llm/groq_llm_configuration.py index 63893e6c..44b55cbf 100644 --- a/council/llm/groq_llm_configuration.py +++ b/council/llm/groq_llm_configuration.py @@ -1,8 +1,16 @@ from __future__ import annotations -from typing import Any, Final, Optional - -from council.utils import Parameter, not_empty_validator, read_env_str, zero_to_one_validator +from typing import Final, List, Mapping, Optional, Tuple, Type + +from council.utils import ( + Parameter, + not_empty_validator, + penalty_validator, + positive_validator, + read_env_str, + zero_to_one_validator, + zero_to_two_validator, +) from . 
import LLMConfigSpec, LLMConfigurationBase @@ -21,9 +29,17 @@ def __init__(self, model: str, api_key: str) -> None: super().__init__() self._model = Parameter.string(name="model", required=True, value=model, validator=not_empty_validator) self._api_key = Parameter.string(name="api_key", required=True, value=api_key, validator=not_empty_validator) + + # https://console.groq.com/docs/api-reference#chat + self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=penalty_validator) + self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=positive_validator) + self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=penalty_validator) + self._seed = Parameter.int(name="seed", required=False) + self._stop = Parameter.string(name="stop", required=False) self._temperature = Parameter.float( - name="temperature", required=False, default=0.0, validator=zero_to_one_validator + name="temperature", required=False, default=0.0, validator=zero_to_two_validator ) + self._top_p = Parameter.float(name="top_p", required=False, validator=zero_to_one_validator) def model_name(self) -> str: return self._model.unwrap() @@ -42,15 +58,55 @@ def api_key(self) -> Parameter[str]: """ return self._api_key + @property + def frequency_penalty(self) -> Parameter[float]: + """ + Number between -2.0 and 2.0. + Positive values penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + """ + return self._frequency_penalty + + @property + def max_tokens(self) -> Parameter[int]: + """Maximum number of tokens to generate.""" + return self._max_tokens + + @property + def presence_penalty(self) -> Parameter[float]: + """ + Number between -2.0 and 2.0. + Positive values penalize new tokens based on whether they appear in the text so far, + increasing the model's likelihood to talk about new topics. + """ + return self._presence_penalty + + @property + def seed(self) -> Parameter[int]: + """Random seed for generation.""" + return self._seed + + @property + def stop(self) -> Parameter[str]: + """Stop sequence.""" + return self._stop + + @property + def stop_value(self) -> Optional[List[str]]: + """Format `stop` parameter. Only single value is supported currently.""" + return [self.stop.value] if self.stop.value is not None else None + @property def temperature(self) -> Parameter[float]: """ - Amount of randomness injected into the response. - Ranges from 0 to 1. + What sampling temperature to use, between 0 and 2. 
""" return self._temperature - # TODO: more parameters + @property + def top_p(self) -> Parameter[float]: + """Nucleus sampling threshold.""" + return self._top_p @staticmethod def from_env() -> GroqLLMConfiguration: @@ -66,8 +122,19 @@ def from_spec(spec: LLMConfigSpec) -> GroqLLMConfiguration: config = GroqLLMConfiguration(model=str(model), api_key=str(api_key)) if spec.parameters is not None: - value: Optional[Any] = spec.parameters.get("temperature", None) - if value is not None: - config.temperature.set(float(value)) + param_mapping: Mapping[str, Tuple[Parameter, Type]] = { + "frequencyPenalty": (config.frequency_penalty, float), + "maxTokens": (config.max_tokens, int), + "presencePenalty": (config.presence_penalty, float), + "seed": (config.seed, int), + "stop": (config.stop, list), + "temperature": (config.temperature, float), + "topP": (config.top_p, float), + } + + for key, (param, type_conv) in param_mapping.items(): + value = spec.parameters.get(key) + if value is not None: + param.set(type_conv(value)) return config From 93a9df6d9fde3ba2432f5eecd7995207be04a368 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 16:27:56 -0500 Subject: [PATCH 10/14] Test parameters --- council/llm/groq_llm.py | 3 +- council/llm/groq_llm_configuration.py | 19 +++++++---- tests/integration/llm/test_groq_llm.py | 45 +++++++++++++++++++------- 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/council/llm/groq_llm.py b/council/llm/groq_llm.py index 8c116e07..4580815a 100644 --- a/council/llm/groq_llm.py +++ b/council/llm/groq_llm.py @@ -82,7 +82,8 @@ def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage] response = self._client.chat.completions.create( messages=formatted_messages, model=self._configuration.model_name(), - temperature=self._configuration.temperature.value, + **self._configuration.params_to_args(), + **kwargs, ) return LLMResult( diff --git a/council/llm/groq_llm_configuration.py b/council/llm/groq_llm_configuration.py index 44b55cbf..71a0bbcb 100644 --- a/council/llm/groq_llm_configuration.py +++ b/council/llm/groq_llm_configuration.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Final, List, Mapping, Optional, Tuple, Type +from typing import Any, Dict, Final, Mapping, Tuple, Type from council.utils import ( Parameter, @@ -91,11 +91,6 @@ def stop(self) -> Parameter[str]: """Stop sequence.""" return self._stop - @property - def stop_value(self) -> Optional[List[str]]: - """Format `stop` parameter. 
Only single value is supported currently.""" - return [self.stop.value] if self.stop.value is not None else None - @property def temperature(self) -> Parameter[float]: """ @@ -108,6 +103,18 @@ def top_p(self) -> Parameter[float]: """Nucleus sampling threshold.""" return self._top_p + def params_to_args(self) -> Dict[str, Any]: + """Convert parameters to options dict""" + return { + "frequency_penalty": self.frequency_penalty.value, + "max_tokens": self.max_tokens.value, + "presence_penalty": self.presence_penalty.value, + "seed": self.seed.value, + "stop": self.stop.value, + "temperature": self.temperature.value, + "top_p": self.top_p.value, + } + @staticmethod def from_env() -> GroqLLMConfiguration: api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap() diff --git a/tests/integration/llm/test_groq_llm.py b/tests/integration/llm/test_groq_llm.py index a88efe42..a61d01ee 100644 --- a/tests/integration/llm/test_groq_llm.py +++ b/tests/integration/llm/test_groq_llm.py @@ -1,33 +1,54 @@ import unittest import dotenv +import json + from council import LLMContext from council.llm import LLMMessage, GroqLLM from council.utils import OsEnviron class TestGroqLLM(unittest.TestCase): + @staticmethod + def get_gemma_7b(): + dotenv.load_dotenv() + with OsEnviron("GROQ_LLM_MODEL", "gemma-7b-it"): + return GroqLLM.from_env() + def test_message(self): messages = [LLMMessage.user_message("what is the capital of France?")] - dotenv.load_dotenv() - with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"): - instance = GroqLLM.from_env() + llm = self.get_gemma_7b() - result = instance.post_chat_request(LLMContext.empty(), messages) + result = llm.post_chat_request(LLMContext.empty(), messages) assert "Paris" in result.choices[0] messages.append(LLMMessage.user_message("give a famous monument of that place")) - result = instance.post_chat_request(LLMContext.empty(), messages) + result = llm.post_chat_request(LLMContext.empty(), messages) assert "Eiffel" in result.choices[0] def test_consumptions(self): messages = [LLMMessage.user_message("Hello how are you?")] - dotenv.load_dotenv() - with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"): - instance = GroqLLM.from_env() - result = instance.post_chat_request(LLMContext.empty(), messages) + llm = self.get_gemma_7b() + result = llm.post_chat_request(LLMContext.empty(), messages) + + assert len(result.consumptions) == 12 # call, duration, 3 token kinds, 3 cost kinds and 4 groq duration + for consumption in result.consumptions: + assert consumption.kind.startswith("gemma-7b-it") + + def test_max_tokens_param(self): + llm = self.get_gemma_7b() + llm.configuration.temperature.set(0.8) + llm.configuration.max_tokens.set(7) + + messages = [LLMMessage.user_message("Hey how are you?")] + result = llm.post_chat_request(LLMContext.empty(), messages) + print(f"Predicted: {result.first_choice}") + + def test_json_mode(self): + messages = [LLMMessage.user_message("Output a JSON object with the data about RPG character.")] + llm = self.get_gemma_7b() + result = llm.post_chat_request(LLMContext.empty(), messages, response_format={"type": "json_object"}) - assert len(result.consumptions) == 12 # call, duration, 3 token kinds, 3 cost kinds and 4 groq duration - for consumption in result.consumptions: - assert consumption.kind.startswith("llama-3.2-1b-preview") + data = json.loads(result.first_choice) + assert isinstance(data, dict) From 1ff3ad82481e6ed57ca2d8a43903de2daa4fff5f Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 16:33:19 -0500 Subject: 
[PATCH 11/14] Add costs --- council/llm/data/groq-costs.yaml | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/council/llm/data/groq-costs.yaml b/council/llm/data/groq-costs.yaml index 942fadd7..dc4cde42 100644 --- a/council/llm/data/groq-costs.yaml +++ b/council/llm/data/groq-costs.yaml @@ -10,6 +10,45 @@ spec: description: | Costs for Groq LLMs models: + gemma2-9b-it: + input: 0.20 + output: 0.20 + gemma-7b-it: + input: 0.07 + output: 0.07 + llama3-groq-70b-8192-tool-use-preview: + input: 0.89 + output: 0.89 + llama3-groq-8b-8192-tool-use-preview: + input: 0.19 + output: 0.19 + llama-3.1-70b-versatile: + input: 0.59 + output: 0.79 + llama-3.1-8b-instant: + input: 0.05 + output: 0.08 llama-3.2-1b-preview: input: 0.04 output: 0.04 + llama-3.2-3b-preview: + input: 0.06 + output: 0.06 + llama-3.2-11b-vision-preview: + input: 0.18 + output: 0.18 + llama-3.2-90b-vision-preview: + input: 0.90 + output: 0.90 + llama-guard-3-8b: + input: 0.20 + output: 0.20 + llama3-70b-8192: + input: 0.59 + output: 0.79 + llama3-8b-8192: + input: 0.05 + output: 0.08 + mixtral-8x7b-32768: + input: 0.24 + output: 0.24 \ No newline at end of file From 89f099696697f637aabb8ddad06bfd3594359ecd Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 18 Nov 2024 16:46:17 -0500 Subject: [PATCH 12/14] Update docs --- docs/data/configs/llm-config-groq.yaml | 2 ++ docs/source/reference/llm/groq_llm.md | 12 ++++++++++++ docs/source/reference/llm/groq_llm_configuration.md | 6 ++++++ 3 files changed, 20 insertions(+) create mode 100644 docs/source/reference/llm/groq_llm.md create mode 100644 docs/source/reference/llm/groq_llm_configuration.md diff --git a/docs/data/configs/llm-config-groq.yaml b/docs/data/configs/llm-config-groq.yaml index c14c19a8..61c4e4b0 100644 --- a/docs/data/configs/llm-config-groq.yaml +++ b/docs/data/configs/llm-config-groq.yaml @@ -14,4 +14,6 @@ spec: apiKey: fromEnvVar: GROQ_API_KEY parameters: + maxTokens: 128 + seed: 42 temperature: 0 diff --git a/docs/source/reference/llm/groq_llm.md b/docs/source/reference/llm/groq_llm.md new file mode 100644 index 00000000..94d63e45 --- /dev/null +++ b/docs/source/reference/llm/groq_llm.md @@ -0,0 +1,12 @@ +# GroqLLM + +```{eval-rst} +.. autoclasstree:: council.llm.GroqLLM + :full: + :namespace: council +``` + +```{eval-rst} +.. autoclass:: council.llm.GroqLLM + :member-order: bysource +``` diff --git a/docs/source/reference/llm/groq_llm_configuration.md b/docs/source/reference/llm/groq_llm_configuration.md new file mode 100644 index 00000000..17dfa39b --- /dev/null +++ b/docs/source/reference/llm/groq_llm_configuration.md @@ -0,0 +1,6 @@ +# GroqLLMConfiguration + +```{eval-rst} +.. 
autoclass:: council.llm.GroqLLMConfiguration + :member-order: bysource +``` From b733596dacc159862accb0c252b5b3c0b8d3da80 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Tue, 19 Nov 2024 10:48:43 -0500 Subject: [PATCH 13/14] Separate DefaultLLMConsumptionCalculatorHelper to use in GroqConsumptionCalculator --- council/llm/__init__.py | 1 + council/llm/groq_llm.py | 9 ++++----- council/llm/llm_cost.py | 32 +++++++++++++++++--------------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/council/llm/__init__.py b/council/llm/__init__.py index 8c893684..4617f623 100644 --- a/council/llm/__init__.py +++ b/council/llm/__init__.py @@ -13,6 +13,7 @@ LLMCostCard, LLMConsumptionCalculatorBase, DefaultLLMConsumptionCalculator, + DefaultLLMConsumptionCalculatorHelper, TokenKind, LLMCostManagerSpec, LLMCostManagerObject, diff --git a/council/llm/groq_llm.py b/council/llm/groq_llm.py index 4580815a..a162c0d2 100644 --- a/council/llm/groq_llm.py +++ b/council/llm/groq_llm.py @@ -4,7 +4,7 @@ from council.contexts import Consumption, LLMContext from council.llm import ( - DefaultLLMConsumptionCalculator, + DefaultLLMConsumptionCalculatorHelper, GroqLLMConfiguration, LLMBase, LLMConfigObject, @@ -27,15 +27,14 @@ from groq.types.chat.chat_completion import ChatCompletion, Choice -class GroqConsumptionCalculator(DefaultLLMConsumptionCalculator): +class GroqConsumptionCalculator(DefaultLLMConsumptionCalculatorHelper): _cost_manager = LLMCostManagerObject.groq() COSTS: Mapping[str, LLMCostCard] = _cost_manager.get_cost_map("default") def __init__(self, model: str) -> None: super().__init__(model) - # TODO: naming - def get_groq_consumptions(self, duration: float, usage: Optional[CompletionUsage]) -> List[Consumption]: + def get_consumptions(self, duration: float, usage: Optional[CompletionUsage]) -> List[Consumption]: if usage is None: return self.get_default_consumptions(duration) @@ -113,7 +112,7 @@ def _to_choices(choices: List[Choice]) -> List[str]: @staticmethod def _to_consumptions(duration: float, response: ChatCompletion) -> Sequence[Consumption]: calculator = GroqConsumptionCalculator(response.model) - return calculator.get_groq_consumptions(duration, response.usage) + return calculator.get_consumptions(duration, response.usage) @staticmethod def from_env() -> GroqLLM: diff --git a/council/llm/llm_cost.py b/council/llm/llm_cost.py index 51322ef6..88347e54 100644 --- a/council/llm/llm_cost.py +++ b/council/llm/llm_cost.py @@ -103,21 +103,7 @@ def filter_zeros(consumptions: List[Consumption]) -> List[Consumption]: return list(filter(lambda consumption: consumption.value > 0, consumptions)) -class DefaultLLMConsumptionCalculator(LLMConsumptionCalculatorBase, abc.ABC): - def get_consumptions(self, duration: float, *, prompt_tokens: int, completion_tokens: int) -> List[Consumption]: - """ - Get default consumptions: - - 1 call - - specified duration - - prompt, completion and total tokens - - corresponding costs if LLMCostCard can be found. 
- """ - base_consumptions = self.get_base_consumptions( - duration, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens - ) - cost_consumptions = self.get_cost_consumptions(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - return base_consumptions + cost_consumptions - +class DefaultLLMConsumptionCalculatorHelper(LLMConsumptionCalculatorBase, abc.ABC): def get_base_consumptions( self, duration: float, *, prompt_tokens: int, completion_tokens: int ) -> List[Consumption]: @@ -142,6 +128,22 @@ def get_cost_consumptions(self, *, prompt_tokens: int, completion_tokens: int) - ] +class DefaultLLMConsumptionCalculator(DefaultLLMConsumptionCalculatorHelper, abc.ABC): + def get_consumptions(self, duration: float, *, prompt_tokens: int, completion_tokens: int) -> List[Consumption]: + """ + Get default consumptions: + - 1 call + - specified duration + - prompt, completion and total tokens + - corresponding costs if LLMCostCard can be found. + """ + base_consumptions = self.get_base_consumptions( + duration, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) + cost_consumptions = self.get_cost_consumptions(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + return base_consumptions + cost_consumptions + + class LLMCostManagerSpec(DataObjectSpecBase): def __init__(self, costs: Dict[str, Dict[str, LLMCostCard]]) -> None: """ From 2259ec5d2ee8ebb680bb7957dcd54b112e3915ac Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Tue, 19 Nov 2024 11:02:50 -0500 Subject: [PATCH 14/14] Update get_duration_consumptions() --- council/llm/groq_llm.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/council/llm/groq_llm.py b/council/llm/groq_llm.py index a162c0d2..63a1c6ee 100644 --- a/council/llm/groq_llm.py +++ b/council/llm/groq_llm.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Mapping, Optional, Sequence +from typing import Any, Dict, List, Mapping, Optional, Sequence from council.contexts import Consumption, LLMContext from council.llm import ( @@ -48,15 +48,18 @@ def get_consumptions(self, duration: float, usage: Optional[CompletionUsage]) -> def get_duration_consumptions(self, usage: CompletionUsage) -> List[Consumption]: """Optional duration consumptions specific to Groq.""" + usage_times: Dict[str, Optional[float]] = { + "queue_time": usage.queue_time, + "prompt_time": usage.prompt_time, + "completion_time": usage.completion_time, + "total_time": usage.total_time, + } + consumptions = [] - if usage.queue_time is not None: - consumptions.append(Consumption.duration(usage.queue_time, f"{self.model}:groq_queue_time")) - if usage.prompt_time is not None: - consumptions.append(Consumption.duration(usage.prompt_time, f"{self.model}:groq_prompt_time")) - if usage.completion_time is not None: - consumptions.append(Consumption.duration(usage.completion_time, f"{self.model}:groq_completion_time")) - if usage.total_time is not None: - consumptions.append(Consumption.duration(usage.total_time, f"{self.model}:groq_total_time")) + for key, value in usage_times.items(): + if value is not None: + consumptions.append(Consumption.duration(value, f"{self.model}:groq_{key}")) + return consumptions def find_model_costs(self) -> Optional[LLMCostCard]: