From 93a9df6d9fde3ba2432f5eecd7995207be04a368 Mon Sep 17 00:00:00 2001
From: Dmytro Nikolaiev
Date: Mon, 18 Nov 2024 16:27:56 -0500
Subject: [PATCH] Test parameters

---
 council/llm/groq_llm.py                |  3 +-
 council/llm/groq_llm_configuration.py  | 19 +++++++----
 tests/integration/llm/test_groq_llm.py | 45 +++++++++++++++++++-------
 3 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/council/llm/groq_llm.py b/council/llm/groq_llm.py
index 8c116e07..4580815a 100644
--- a/council/llm/groq_llm.py
+++ b/council/llm/groq_llm.py
@@ -82,7 +82,8 @@ def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage]
         response = self._client.chat.completions.create(
             messages=formatted_messages,
             model=self._configuration.model_name(),
-            temperature=self._configuration.temperature.value,
+            **self._configuration.params_to_args(),
+            **kwargs,
         )
 
         return LLMResult(
diff --git a/council/llm/groq_llm_configuration.py b/council/llm/groq_llm_configuration.py
index 44b55cbf..71a0bbcb 100644
--- a/council/llm/groq_llm_configuration.py
+++ b/council/llm/groq_llm_configuration.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Final, List, Mapping, Optional, Tuple, Type
+from typing import Any, Dict, Final, Mapping, Tuple, Type
 
 from council.utils import (
     Parameter,
@@ -91,11 +91,6 @@ def stop(self) -> Parameter[str]:
         """Stop sequence."""
         return self._stop
 
-    @property
-    def stop_value(self) -> Optional[List[str]]:
-        """Format `stop` parameter. Only single value is supported currently."""
-        return [self.stop.value] if self.stop.value is not None else None
-
     @property
     def temperature(self) -> Parameter[float]:
         """
@@ -108,6 +103,18 @@ def top_p(self) -> Parameter[float]:
         """Nucleus sampling threshold."""
         return self._top_p
 
+    def params_to_args(self) -> Dict[str, Any]:
+        """Convert parameters to options dict"""
+        return {
+            "frequency_penalty": self.frequency_penalty.value,
+            "max_tokens": self.max_tokens.value,
+            "presence_penalty": self.presence_penalty.value,
+            "seed": self.seed.value,
+            "stop": self.stop.value,
+            "temperature": self.temperature.value,
+            "top_p": self.top_p.value,
+        }
+
     @staticmethod
     def from_env() -> GroqLLMConfiguration:
         api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap()
diff --git a/tests/integration/llm/test_groq_llm.py b/tests/integration/llm/test_groq_llm.py
index a88efe42..a61d01ee 100644
--- a/tests/integration/llm/test_groq_llm.py
+++ b/tests/integration/llm/test_groq_llm.py
@@ -1,33 +1,54 @@
 import unittest
 
 import dotenv
+import json
+
 from council import LLMContext
 from council.llm import LLMMessage, GroqLLM
 from council.utils import OsEnviron
 
 
 class TestGroqLLM(unittest.TestCase):
+    @staticmethod
+    def get_gemma_7b():
+        dotenv.load_dotenv()
+        with OsEnviron("GROQ_LLM_MODEL", "gemma-7b-it"):
+            return GroqLLM.from_env()
+
     def test_message(self):
         messages = [LLMMessage.user_message("what is the capital of France?")]
-        dotenv.load_dotenv()
-        with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"):
-            instance = GroqLLM.from_env()
+        llm = self.get_gemma_7b()
 
-        result = instance.post_chat_request(LLMContext.empty(), messages)
+        result = llm.post_chat_request(LLMContext.empty(), messages)
         assert "Paris" in result.choices[0]
 
         messages.append(LLMMessage.user_message("give a famous monument of that place"))
-        result = instance.post_chat_request(LLMContext.empty(), messages)
+        result = llm.post_chat_request(LLMContext.empty(), messages)
         assert "Eiffel" in result.choices[0]
 
     def test_consumptions(self):
         messages = [LLMMessage.user_message("Hello how are you?")]
you?")] - dotenv.load_dotenv() - with OsEnviron("GROQ_LLM_MODEL", "llama-3.2-1b-preview"): - instance = GroqLLM.from_env() - result = instance.post_chat_request(LLMContext.empty(), messages) + llm = self.get_gemma_7b() + result = llm.post_chat_request(LLMContext.empty(), messages) + + assert len(result.consumptions) == 12 # call, duration, 3 token kinds, 3 cost kinds and 4 groq duration + for consumption in result.consumptions: + assert consumption.kind.startswith("gemma-7b-it") + + def test_max_tokens_param(self): + llm = self.get_gemma_7b() + llm.configuration.temperature.set(0.8) + llm.configuration.max_tokens.set(7) + + messages = [LLMMessage.user_message("Hey how are you?")] + result = llm.post_chat_request(LLMContext.empty(), messages) + print(f"Predicted: {result.first_choice}") + + def test_json_mode(self): + messages = [LLMMessage.user_message("Output a JSON object with the data about RPG character.")] + llm = self.get_gemma_7b() + result = llm.post_chat_request(LLMContext.empty(), messages, response_format={"type": "json_object"}) - assert len(result.consumptions) == 12 # call, duration, 3 token kinds, 3 cost kinds and 4 groq duration - for consumption in result.consumptions: - assert consumption.kind.startswith("llama-3.2-1b-preview") + data = json.loads(result.first_choice) + assert isinstance(data, dict)