Skip to content

Commit

Permalink
Add more parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Winston-503 committed Nov 18, 2024
1 parent ef5c831 commit a96be04
Showing 1 changed file with 77 additions and 10 deletions.
87 changes: 77 additions & 10 deletions council/llm/groq_llm_configuration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from __future__ import annotations

from typing import Any, Final, Optional

from council.utils import Parameter, not_empty_validator, read_env_str, zero_to_one_validator
from typing import Final, List, Mapping, Optional, Tuple, Type

from council.utils import (
Parameter,
not_empty_validator,
penalty_validator,
positive_validator,
read_env_str,
zero_to_one_validator,
zero_to_two_validator,
)

from . import LLMConfigSpec, LLMConfigurationBase

Expand All @@ -21,9 +29,17 @@ def __init__(self, model: str, api_key: str) -> None:
super().__init__()
self._model = Parameter.string(name="model", required=True, value=model, validator=not_empty_validator)
self._api_key = Parameter.string(name="api_key", required=True, value=api_key, validator=not_empty_validator)

# https://console.groq.com/docs/api-reference#chat
self._frequency_penalty = Parameter.float(name="frequency_penalty", required=False, validator=penalty_validator)
self._max_tokens = Parameter.int(name="max_tokens", required=False, validator=positive_validator)
self._presence_penalty = Parameter.float(name="presence_penalty", required=False, validator=penalty_validator)
self._seed = Parameter.int(name="seed", required=False)
self._stop = Parameter.string(name="stop", required=False)
self._temperature = Parameter.float(
name="temperature", required=False, default=0.0, validator=zero_to_one_validator
name="temperature", required=False, default=0.0, validator=zero_to_two_validator
)
self._top_p = Parameter.float(name="top_p", required=False, validator=zero_to_one_validator)

def model_name(self) -> str:
return self._model.unwrap()
Expand All @@ -42,15 +58,55 @@ def api_key(self) -> Parameter[str]:
"""
return self._api_key

@property
def frequency_penalty(self) -> Parameter[float]:
"""
Number between -2.0 and 2.0.
Positive values penalize new tokens based on their existing frequency in the text so far,
decreasing the model's likelihood to repeat the same line verbatim.
"""
return self._frequency_penalty

@property
def max_tokens(self) -> Parameter[int]:
"""Maximum number of tokens to generate."""
return self._max_tokens

@property
def presence_penalty(self) -> Parameter[float]:
"""
Number between -2.0 and 2.0.
Positive values penalize new tokens based on whether they appear in the text so far,
increasing the model's likelihood to talk about new topics.
"""
return self._presence_penalty

@property
def seed(self) -> Parameter[int]:
"""Random seed for generation."""
return self._seed

@property
def stop(self) -> Parameter[str]:
"""Stop sequence."""
return self._stop

@property
def stop_value(self) -> Optional[List[str]]:
"""Format `stop` parameter. Only single value is supported currently."""
return [self.stop.value] if self.stop.value is not None else None

@property
def temperature(self) -> Parameter[float]:
"""
Amount of randomness injected into the response.
Ranges from 0 to 1.
What sampling temperature to use, between 0 and 2.
"""
return self._temperature

# TODO: more parameters
@property
def top_p(self) -> Parameter[float]:
"""Nucleus sampling threshold."""
return self._top_p

@staticmethod
def from_env() -> GroqLLMConfiguration:
Expand All @@ -66,8 +122,19 @@ def from_spec(spec: LLMConfigSpec) -> GroqLLMConfiguration:
config = GroqLLMConfiguration(model=str(model), api_key=str(api_key))

if spec.parameters is not None:
value: Optional[Any] = spec.parameters.get("temperature", None)
if value is not None:
config.temperature.set(float(value))
param_mapping: Mapping[str, Tuple[Parameter, Type]] = {
"frequencyPenalty": (config.frequency_penalty, float),
"maxTokens": (config.max_tokens, int),
"presencePenalty": (config.presence_penalty, float),
"seed": (config.seed, int),
"stop": (config.stop, list),
"temperature": (config.temperature, float),
"topP": (config.top_p, float),
}

for key, (param, type_conv) in param_mapping.items():
value = spec.parameters.get(key)
if value is not None:
param.set(type_conv(value))

return config

0 comments on commit a96be04

Please sign in to comment.