Feature initial groq implementation #190

Merged · 14 commits · Nov 20, 2024
8 changes: 8 additions & 0 deletions council/llm/__init__.py
@@ -13,6 +13,7 @@
LLMCostCard,
LLMConsumptionCalculatorBase,
DefaultLLMConsumptionCalculator,
DefaultLLMConsumptionCalculatorHelper,
TokenKind,
LLMCostManagerSpec,
LLMCostManagerObject,
@@ -61,6 +62,9 @@
from .ollama_llm_configuration import OllamaLLMConfiguration
from .ollama_llm import OllamaLLM

from .groq_llm_configuration import GroqLLMConfiguration
from .groq_llm import GroqLLM


def get_default_llm(max_retries: Optional[int] = None) -> LLMBase:
provider = read_env_str("COUNCIL_DEFAULT_LLM_PROVIDER", default=LLMProviders.OpenAI).unwrap()
@@ -77,6 +81,8 @@ def get_default_llm(max_retries: Optional[int] = None) -> LLMBase:
llm = GeminiLLM.from_env()
elif provider == LLMProviders.Ollama.lower():
llm = OllamaLLM.from_env()
elif provider == LLMProviders.Groq.lower():
llm = GroqLLM.from_env()

if llm is None:
raise ValueError(f"Provider {provider} not supported by Council.")
@@ -111,5 +117,7 @@ def _build_llm(llm_config: LLMConfigObject) -> LLMBase:
return GeminiLLM.from_config(llm_config)
elif provider.is_of_kind(LLMProviders.Ollama):
return OllamaLLM.from_config(llm_config)
elif provider.is_of_kind(LLMProviders.Groq):
return GroqLLM.from_config(llm_config)

raise ValueError(f"Provider `{provider.kind}` not supported by Council")
43 changes: 9 additions & 34 deletions council/llm/chat_gpt_configuration.py
@@ -2,34 +2,7 @@
from typing import Any, Dict, Optional

from council.llm.llm_base import LLMConfigurationBase
from council.utils import Parameter


def _tv(x: float):
"""
Temperature Validator
Sampling temperature to use, between 0. and 2.
"""
if x < 0.0 or x > 2.0:
raise ValueError("must be in the range [0.0..2.0]")


def _pv(x: float):
"""
Penalty Validator
Penalty must be between -2.0 and 2.0
"""
if x < -2.0 or x > 2.0:
raise ValueError("must be in the range [-2.0..2.0]")


def _mtv(x: int):
"""
Max Token Validator
Must be positive
"""
if x <= 0:
raise ValueError("must be positive")
from council.utils import Parameter, penalty_validator, positive_validator, zero_to_one_validator, zero_to_two_validator


class ChatGPTConfigurationBase(LLMConfigurationBase, ABC):
@@ -38,12 +11,14 @@ class ChatGPTConfigurationBase(LLMConfigurationBase, ABC):
"""

def __init__(self) -> None:
self._temperature = Parameter.float(name="temperature", required=False, default=0.0, validator=_tv)
self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=_mtv)
self._top_p = Parameter.float(name="top_p", required=False)
self._n = Parameter.int(name="n", required=False, default=1)
self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=_pv)
self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=_pv)
self._temperature = Parameter.float(
name="temperature", required=False, default=0.0, validator=zero_to_two_validator
)
self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=positive_validator)
self._top_p = Parameter.float(name="top_p", required=False, validator=zero_to_one_validator)
self._n = Parameter.int(name="n", required=False, default=1, validator=positive_validator)
self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=penalty_validator)
self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=penalty_validator)

@property
def temperature(self) -> Parameter[float]:
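The local _tv, _pv and _mtv helpers are replaced by shared validators imported from council.utils. Their implementations are not part of this diff; a minimal sketch of the expected behavior, reconstructed from the validators they replace (only the names come from the import, the bodies below are assumptions):

def zero_to_two_validator(x: float) -> None:
    # sampling temperature: must be in [0.0, 2.0]
    if x < 0.0 or x > 2.0:
        raise ValueError("must be in the range [0.0..2.0]")


def zero_to_one_validator(x: float) -> None:
    # top_p nucleus sampling: must be in [0.0, 1.0]
    if x < 0.0 or x > 1.0:
        raise ValueError("must be in the range [0.0..1.0]")


def penalty_validator(x: float) -> None:
    # presence/frequency penalty: must be in [-2.0, 2.0]
    if x < -2.0 or x > 2.0:
        raise ValueError("must be in the range [-2.0..2.0]")


def positive_validator(x: int) -> None:
    # max_tokens and n: must be strictly positive
    if x <= 0:
        raise ValueError("must be positive")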
54 changes: 54 additions & 0 deletions council/llm/data/groq-costs.yaml
@@ -0,0 +1,54 @@
kind: LLMCostManager
version: 0.1
metadata:
name: groq-costs
labels:
provider: Groq
reference: https://groq.com/pricing/
spec:
default:
description: |
Costs for Groq LLMs
models:
gemma2-9b-it:
input: 0.20
output: 0.20
gemma-7b-it:
input: 0.07
output: 0.07
llama3-groq-70b-8192-tool-use-preview:
input: 0.89
output: 0.89
llama3-groq-8b-8192-tool-use-preview:
input: 0.19
output: 0.19
llama-3.1-70b-versatile:
input: 0.59
output: 0.79
llama-3.1-8b-instant:
input: 0.05
output: 0.08
llama-3.2-1b-preview:
input: 0.04
output: 0.04
llama-3.2-3b-preview:
input: 0.06
output: 0.06
llama-3.2-11b-vision-preview:
input: 0.18
output: 0.18
llama-3.2-90b-vision-preview:
input: 0.90
output: 0.90
llama-guard-3-8b:
input: 0.20
output: 0.20
llama3-70b-8192:
input: 0.59
output: 0.79
llama3-8b-8192:
input: 0.05
output: 0.08
mixtral-8x7b-32768:
input: 0.24
output: 0.24
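The rates mirror Groq's published pricing and are interpreted here as USD per million tokens (an assumption consistent with the pricing page the file references). For example, a llama-3.1-70b-versatile call with 1,000 prompt tokens and 500 completion tokens would cost about:

prompt_tokens, completion_tokens = 1_000, 500
input_rate, output_rate = 0.59, 0.79  # USD per 1M tokens, from the table above

cost = prompt_tokens / 1e6 * input_rate + completion_tokens / 1e6 * output_rate
print(f"${cost:.6f}")  # $0.000985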
137 changes: 137 additions & 0 deletions council/llm/groq_llm.py
@@ -0,0 +1,137 @@
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence

from council.contexts import Consumption, LLMContext
from council.llm import (
DefaultLLMConsumptionCalculatorHelper,
GroqLLMConfiguration,
LLMBase,
LLMConfigObject,
LLMCostCard,
LLMCostManagerObject,
LLMMessage,
LLMMessageRole,
LLMProviders,
LLMResult,
)
from council.utils.utils import DurationManager
from groq import Groq
from groq.types import CompletionUsage
from groq.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from groq.types.chat.chat_completion import ChatCompletion, Choice


class GroqConsumptionCalculator(DefaultLLMConsumptionCalculatorHelper):
_cost_manager = LLMCostManagerObject.groq()
COSTS: Mapping[str, LLMCostCard] = _cost_manager.get_cost_map("default")

def __init__(self, model: str) -> None:
super().__init__(model)

def get_consumptions(self, duration: float, usage: Optional[CompletionUsage]) -> List[Consumption]:
if usage is None:
return self.get_default_consumptions(duration)

prompt_tokens = usage.prompt_tokens
completion_tokens = usage.completion_tokens
return (
self.get_base_consumptions(duration, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+ self.get_duration_consumptions(usage)
+ self.get_cost_consumptions(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
)

def get_duration_consumptions(self, usage: CompletionUsage) -> List[Consumption]:
"""Optional duration consumptions specific to Groq."""
usage_times: Dict[str, Optional[float]] = {
"queue_time": usage.queue_time,
"prompt_time": usage.prompt_time,
"completion_time": usage.completion_time,
"total_time": usage.total_time,
}

consumptions = []
for key, value in usage_times.items():
if value is not None:
consumptions.append(Consumption.duration(value, f"{self.model}:groq_{key}"))

return consumptions

def find_model_costs(self) -> Optional[LLMCostCard]:
return self.COSTS.get(self.model)


class GroqLLM(LLMBase[GroqLLMConfiguration]):
def __init__(self, config: GroqLLMConfiguration) -> None:
"""
Initialize a new instance.

Args:
config(GroqLLMConfiguration): configuration for the instance
"""
super().__init__(name=f"{self.__class__.__name__}", configuration=config)
self._client = Groq(api_key=config.api_key.value)

def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
formatted_messages = self._build_messages_payload(messages)

with DurationManager() as timer:
response = self._client.chat.completions.create(
messages=formatted_messages,
model=self._configuration.model_name(),
**self._configuration.params_to_args(),
**kwargs,
)

return LLMResult(
choices=self._to_choices(response.choices),
consumptions=self._to_consumptions(timer.duration, response),
raw_response=response.to_dict(),
)

@staticmethod
def _build_messages_payload(messages: Sequence[LLMMessage]) -> List[ChatCompletionMessageParam]:
def _llm_message_to_groq_message(message: LLMMessage) -> ChatCompletionMessageParam:
if message.is_of_role(LLMMessageRole.System):
return ChatCompletionSystemMessageParam(role="system", content=message.content)
elif message.is_of_role(LLMMessageRole.User):
return ChatCompletionUserMessageParam(role="user", content=message.content)
elif message.is_of_role(LLMMessageRole.Assistant):
return ChatCompletionAssistantMessageParam(role="assistant", content=message.content)

raise ValueError(f"Unknown LLMessage role: `{message.role.value}`")

return [_llm_message_to_groq_message(message) for message in messages]

@staticmethod
def _to_choices(choices: List[Choice]) -> List[str]:
return [choice.message.content if choice.message.content is not None else "" for choice in choices]

@staticmethod
def _to_consumptions(duration: float, response: ChatCompletion) -> Sequence[Consumption]:
calculator = GroqConsumptionCalculator(response.model)
return calculator.get_consumptions(duration, response.usage)

@staticmethod
def from_env() -> GroqLLM:
"""
Helper function that creates a new instance by getting the configuration from environment variables.

Returns:
GroqLLM
"""
return GroqLLM(GroqLLMConfiguration.from_env())

@staticmethod
def from_config(config_object: LLMConfigObject) -> GroqLLM:
provider = config_object.spec.provider
if not provider.is_of_kind(LLMProviders.Groq):
raise ValueError(f"Invalid LLM provider, actual {provider}, expected {LLMProviders.Groq}")

config = GroqLLMConfiguration.from_spec(config_object.spec)
return GroqLLM(config=config)
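End to end, the new class follows the same call pattern as council's other LLM wrappers. A minimal sketch, assuming GroqLLMConfiguration.from_env reads GROQ_API_KEY and GROQ_LLM_MODEL (the variable names are assumptions mirroring the other providers):

from council.contexts import LLMContext
from council.llm import GroqLLM, LLMMessage

llm = GroqLLM.from_env()  # requires the Groq env variables noted above

result = llm.post_chat_request(
    LLMContext.empty(),
    [LLMMessage.user_message("What is the capital of France?")],
)
print(result.first_choice)  # Groq-specific duration metrics land in result.consumptions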