Commit: Test parameters

Winston-503 committed Nov 18, 2024
1 parent a96be04 commit 93a9df6
Showing 3 changed files with 48 additions and 19 deletions.
3 changes: 2 additions & 1 deletion council/llm/groq_llm.py
@@ -82,7 +82,8 @@ def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage]
         response = self._client.chat.completions.create(
             messages=formatted_messages,
             model=self._configuration.model_name(),
-            temperature=self._configuration.temperature.value,
+            **self._configuration.params_to_args(),
+            **kwargs,
         )

         return LLMResult(
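
With this change, `_post_chat_request` forwards every configured parameter via `params_to_args()` and then unpacks the caller's `kwargs`, so per-request options such as `response_format` reach the Groq client. A minimal, self-contained sketch of the `**`-unpacking behavior; the `create` function below is a toy stand-in, not the Groq SDK:

```python
# Stand-in illustration of how two **-unpacked dicts reach one call; this
# `create` is a hypothetical stub, not groq's chat.completions.create.
def create(messages, model, temperature=None, max_tokens=None, response_format=None):
    return {"model": model, "temperature": temperature,
            "max_tokens": max_tokens, "response_format": response_format}

configured = {"temperature": 0.8, "max_tokens": 7}          # e.g. params_to_args()
per_request = {"response_format": {"type": "json_object"}}  # e.g. caller kwargs

print(create(messages=[], model="gemma-7b-it", **configured, **per_request))
# {'model': 'gemma-7b-it', 'temperature': 0.8, 'max_tokens': 7,
#  'response_format': {'type': 'json_object'}}

# Caveat: if the same key appeared in both dicts, Python would raise
# "TypeError: create() got multiple values for keyword argument ...",
# so kwargs must not duplicate keys already emitted by params_to_args().
```
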
19 changes: 13 additions & 6 deletions council/llm/groq_llm_configuration.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Final, List, Mapping, Optional, Tuple, Type
+from typing import Any, Dict, Final, Mapping, Tuple, Type

 from council.utils import (
     Parameter,
@@ -91,11 +91,6 @@ def stop(self) -> Parameter[str]:
         """Stop sequence."""
         return self._stop

-    @property
-    def stop_value(self) -> Optional[List[str]]:
-        """Format `stop` parameter. Only single value is supported currently."""
-        return [self.stop.value] if self.stop.value is not None else None
-
     @property
     def temperature(self) -> Parameter[float]:
         """
@@ -108,6 +103,18 @@ def top_p(self) -> Parameter[float]:
         """Nucleus sampling threshold."""
         return self._top_p

+    def params_to_args(self) -> Dict[str, Any]:
+        """Convert parameters to options dict"""
+        return {
+            "frequency_penalty": self.frequency_penalty.value,
+            "max_tokens": self.max_tokens.value,
+            "presence_penalty": self.presence_penalty.value,
+            "seed": self.seed.value,
+            "stop": self.stop.value,
+            "temperature": self.temperature.value,
+            "top_p": self.top_p.value,
+        }
+
     @staticmethod
     def from_env() -> GroqLLMConfiguration:
         api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap()
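
Note that `params_to_args()` subsumes the deleted `stop_value` property: `stop` is now forwarded as a single string rather than wrapped in a one-element list, and parameters that were never set contribute `None` values. A rough stand-in (hypothetical classes, not the real `Parameter`/`GroqLLMConfiguration`) showing the resulting dict shape:

```python
# Hypothetical stand-in for Parameter/GroqLLMConfiguration, illustration only.
from typing import Any, Dict, Optional


class Param:
    def __init__(self, value: Optional[Any] = None) -> None:
        self.value = value


class FakeGroqConfig:
    def __init__(self) -> None:
        self.temperature = Param(0.8)
        self.max_tokens = Param(7)
        self.stop = Param()  # never set -> value stays None

    def params_to_args(self) -> Dict[str, Any]:
        # Mirrors the commit's pattern: read .value from each parameter.
        return {
            "temperature": self.temperature.value,
            "max_tokens": self.max_tokens.value,
            "stop": self.stop.value,
        }


print(FakeGroqConfig().params_to_args())
# {'temperature': 0.8, 'max_tokens': 7, 'stop': None}
```
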
45 changes: 33 additions & 12 deletions tests/integration/llm/test_groq_llm.py
@@ -1,33 +1,54 @@
 import unittest
 import dotenv

+import json
+
 from council import LLMContext
 from council.llm import LLMMessage, GroqLLM
 from council.utils import OsEnviron


 class TestGroqLLM(unittest.TestCase):
+    @staticmethod
+    def get_gemma_7b():
+        dotenv.load_dotenv()
+        with OsEnviron("GROQ_LLM_MODEL", "gemma-7b-it"):
+            return GroqLLM.from_env()
+
     def test_message(self):
         messages = [LLMMessage.user_message("what is the capital of France?")]
-        dotenv.load_dotenv()
-        with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"):
-            instance = GroqLLM.from_env()
+        llm = self.get_gemma_7b()

-            result = instance.post_chat_request(LLMContext.empty(), messages)
+        result = llm.post_chat_request(LLMContext.empty(), messages)
         assert "Paris" in result.choices[0]

         messages.append(LLMMessage.user_message("give a famous monument of that place"))
-            result = instance.post_chat_request(LLMContext.empty(), messages)
+        result = llm.post_chat_request(LLMContext.empty(), messages)

         assert "Eiffel" in result.choices[0]

     def test_consumptions(self):
         messages = [LLMMessage.user_message("Hello how are you?")]
-        dotenv.load_dotenv()
-        with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"):
-            instance = GroqLLM.from_env()
-            result = instance.post_chat_request(LLMContext.empty(), messages)
+        llm = self.get_gemma_7b()
+        result = llm.post_chat_request(LLMContext.empty(), messages)

-            assert len(result.consumptions) == 12  # call, duration, 3 token kinds, 3 cost kinds and 4 groq duration
-            for consumption in result.consumptions:
-                assert consumption.kind.startswith("llama-3.2-1b-preview")
+        assert len(result.consumptions) == 12  # call, duration, 3 token kinds, 3 cost kinds and 4 groq duration
+        for consumption in result.consumptions:
+            assert consumption.kind.startswith("gemma-7b-it")
+
+    def test_max_tokens_param(self):
+        llm = self.get_gemma_7b()
+        llm.configuration.temperature.set(0.8)
+        llm.configuration.max_tokens.set(7)
+
+        messages = [LLMMessage.user_message("Hey how are you?")]
+        result = llm.post_chat_request(LLMContext.empty(), messages)
+        print(f"Predicted: {result.first_choice}")
+
+    def test_json_mode(self):
+        messages = [LLMMessage.user_message("Output a JSON object with the data about RPG character.")]
+        llm = self.get_gemma_7b()
+        result = llm.post_chat_request(LLMContext.empty(), messages, response_format={"type": "json_object"})
+
+        data = json.loads(result.first_choice)
+        assert isinstance(data, dict)

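The new tests exercise both paths: `test_max_tokens_param` sets values through the configuration (picked up by `params_to_args()`), while `test_json_mode` passes `response_format` per request through `**kwargs`. A hedged usage sketch outside the test harness; `GROQ_LLM_MODEL` is confirmed by the tests above, but the exact API-key variable read by `from_env()` is an assumption:

```python
# Sketch under assumptions: GROQ_LLM_MODEL is read by GroqLLM.from_env() (seen
# in the tests); an API key env var (likely GROQ_LLM_API_KEY, inferred from the
# GROQ_LLM_ prefix and read_env_str(_env_var_prefix + "API_KEY")) must be set.
import json
import os

from council import LLMContext
from council.llm import GroqLLM, LLMMessage

os.environ["GROQ_LLM_MODEL"] = "gemma-7b-it"
llm = GroqLLM.from_env()

# Configured parameters flow through params_to_args() on every request...
llm.configuration.max_tokens.set(256)

# ...while per-request options like response_format travel via **kwargs.
messages = [LLMMessage.user_message("Output a JSON object with the data about RPG character.")]
result = llm.post_chat_request(LLMContext.empty(), messages, response_format={"type": "json_object"})

print(json.loads(result.first_choice))
```
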