From b1ee4152dffbeb701df1f17896b349d3f943eb9b Mon Sep 17 00:00:00 2001
From: Arnaud Flament
Date: Fri, 21 Jun 2024 18:31:35 -0700
Subject: [PATCH] Initial Gemini implementation

---
 council/__init__.py                           |   2 +-
 council/llm/__init__.py                       |   5 +
 council/llm/gemini_llm.py                     |  81 ++++++++++++++
 council/llm/gemini_llm_configuration.py       | 104 ++++++++++++++++++
 council/llm/llm_config_object.py              |  10 +-
 dev-requirements.txt                          |   6 +-
 requirements.txt                              |   1 +
 tests/data/gemini-llmodel.yaml                |  17 +++
 tests/integration/llm/test_gemini_llm.py      |  35 ++++++
 tests/unit/__init__.py                        |   1 +
 .../unit/llm/test_gemini_llm_configuration.py |  22 ++++
 11 files changed, 278 insertions(+), 6 deletions(-)
 create mode 100644 council/llm/gemini_llm.py
 create mode 100644 council/llm/gemini_llm_configuration.py
 create mode 100644 tests/data/gemini-llmodel.yaml
 create mode 100644 tests/integration/llm/test_gemini_llm.py
 create mode 100644 tests/unit/llm/test_gemini_llm_configuration.py

diff --git a/council/__init__.py b/council/__init__.py
index ee6a6fe6..e0d49aac 100644
--- a/council/__init__.py
+++ b/council/__init__.py
@@ -6,5 +6,5 @@
 from .controllers import BasicController, ControllerBase, ExecutionUnit, LLMController
 from .evaluators import BasicEvaluator, EvaluatorBase, LLMEvaluator
 from .filters import BasicFilter, FilterBase
-from .llm import AnthropicLLM, AzureLLM, OpenAILLM
+from .llm import AnthropicLLM, AzureLLM, OpenAILLM, GeminiLLM
 from .runners import DoWhile, If, Parallel, ParallelFor, RunnerGenerator, RunnerPredicate, Sequential, While
diff --git a/council/llm/__init__.py b/council/llm/__init__.py
index f33bf214..4dcef28c 100644
--- a/council/llm/__init__.py
+++ b/council/llm/__init__.py
@@ -25,6 +25,9 @@
 from .anthropic_llm_configuration import AnthropicLLMConfiguration
 from .anthropic_llm import AnthropicLLM
 
+from .gemini_llm_configuration import GeminiLLMConfiguration
+from .gemini_llm import GeminiLLM
+
 
 def get_default_llm(max_retries: Optional[int] = None) -> LLMBase:
     provider = read_env_str("COUNCIL_DEFAULT_LLM_PROVIDER", default=LLMProviders.OpenAI).unwrap()
@@ -37,6 +40,8 @@ def get_default_llm(max_retries: Optional[int] = None) -> LLMBase:
         llm = AzureLLM.from_env()
     elif provider == LLMProviders.Anthropic.lower():
         llm = AnthropicLLM.from_env()
+    elif provider == LLMProviders.Gemini.lower():
+        llm = GeminiLLM.from_env()
 
     if llm is None:
         raise ValueError(f"Provider {provider} not supported by council.")
diff --git a/council/llm/gemini_llm.py b/council/llm/gemini_llm.py
new file mode 100644
index 00000000..7465779b
--- /dev/null
+++ b/council/llm/gemini_llm.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Any, List, Sequence, Tuple
+
+import google.generativeai as genai  # type: ignore
+from council.contexts import Consumption, LLMContext
+from council.llm import (
+    GeminiLLMConfiguration,
+    LLMBase,
+    LLMConfigObject,
+    LLMMessage,
+    LLMMessageRole,
+    LLMProviders,
+    LLMResult,
+)
+from google.ai.generativelanguage_v1 import HarmCategory  # type: ignore
+from google.generativeai.types import HarmBlockThreshold  # type: ignore
+
+
+class GeminiLLM(LLMBase[GeminiLLMConfiguration]):
+    def __init__(self, config: GeminiLLMConfiguration) -> None:
+        """
+        Initialize a new instance.
+
+        Args:
+            config(GeminiLLMConfiguration): configuration for the instance
+        """
+        super().__init__(name=f"{self.__class__.__name__}", configuration=config)
+        genai.configure(api_key=config.api_key.value)
+        self._model = genai.GenerativeModel(
+            config.model_name(),
+            safety_settings={
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
+            },
+        )
+
+    def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
+        history, last = self._to_chat_history(messages=messages)
+        chat = self._model.start_chat(history=history)
+        response = chat.send_message(last)
+        return LLMResult(choices=[response.text], consumptions=self.to_consumptions())
+
+    def to_consumptions(self) -> Sequence[Consumption]:
+        model = self._configuration.model_name()
+        return [
+            Consumption(1, "call", f"{model}"),
+        ]
+
+    @staticmethod
+    def from_env() -> GeminiLLM:
+        """
+        Helper function that creates a new instance by getting the configuration from environment variables.
+
+        Returns:
+            GeminiLLM
+        """
+
+        return GeminiLLM(GeminiLLMConfiguration.from_env())
+
+    @staticmethod
+    def from_config(config_object: LLMConfigObject) -> GeminiLLM:
+        provider = config_object.spec.provider
+        if not provider.is_of_kind(LLMProviders.Gemini):
+            raise ValueError(f"Invalid LLM provider, actual {provider}, expected {LLMProviders.Gemini}")
+
+        config = GeminiLLMConfiguration.from_spec(config_object.spec)
+        return GeminiLLM(config=config)
+
+    @staticmethod
+    def _to_chat_history(messages: Sequence[LLMMessage]) -> Tuple[List[Any], Any]:
+        history = []
+        for message in messages[:-1]:
+            if message.is_of_role(LLMMessageRole.System):
+                history.append({"role": "user", "parts": [{"text": f"System Prompt: {message.content}"}]})
+                history.append({"role": "model", "parts": [{"text": "Understood"}]})
+            elif message.is_of_role(LLMMessageRole.User):
+                history.append({"role": "user", "parts": [{"text": message.content}]})
+            elif message.is_of_role(LLMMessageRole.Assistant):
+                history.append({"role": "model", "parts": [{"text": message.content}]})
+        last = messages[-1].content
+        return history, last
diff --git a/council/llm/gemini_llm_configuration.py b/council/llm/gemini_llm_configuration.py
new file mode 100644
index 00000000..c578e9a7
--- /dev/null
+++ b/council/llm/gemini_llm_configuration.py
@@ -0,0 +1,104 @@
+from __future__ import annotations
+
+from typing import Any, Final, Optional
+
+from council.utils import Parameter, greater_than_validator, not_empty_validator, prefix_validator, read_env_str
+
+from . import LLMConfigSpec, LLMConfigurationBase
+
+_env_var_prefix: Final[str] = "GEMINI_"
+
+
+def _tv(x: float) -> None:
+    """
+    Temperature and top_p validator.
+    Values must be between 0.0 and 1.0.
+    """
+    if x < 0.0 or x > 1.0:
+        raise ValueError("must be in the range [0.0..1.0]")
+
+
+class GeminiLLMConfiguration(LLMConfigurationBase):
+    def __init__(self, model: str, api_key: str) -> None:
+        """
+        Initialize a new instance.
+
+        Args:
+            api_key (str): the api key
+            model (str): the model name
+        """
+        super().__init__()
+        self._model = Parameter.string(name="model", required=True, value=model, validator=prefix_validator("gemini-"))
+        self._api_key = Parameter.string(name="api_key", required=True, value=api_key, validator=not_empty_validator)
+        self._temperature = Parameter.float(name="temperature", required=False, default=0.0, validator=_tv)
+        self._top_p = Parameter.float(name="top_p", required=False, validator=_tv)
+        self._top_k = Parameter.int(name="top_k", required=False, validator=greater_than_validator(0))
+
+    def model_name(self) -> str:
+        return self._model.unwrap()
+
+    @property
+    def model(self) -> Parameter[str]:
+        """
+        Gemini model
+        """
+        return self._model
+
+    @property
+    def api_key(self) -> Parameter[str]:
+        """
+        Gemini API Key
+        """
+        return self._api_key
+
+    @property
+    def temperature(self) -> Parameter[float]:
+        """
+        Amount of randomness injected into the response.
+        Ranges from 0 to 1.
+        Use temp closer to 0 for analytical / multiple choice, and closer to 1 for creative and generative tasks.
+        """
+        return self._temperature
+
+    @property
+    def top_p(self) -> Parameter[float]:
+        """
+        Use nucleus sampling.
+        In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in
+        decreasing probability order and cut it off once it reaches a particular probability specified by top_p.
+        """
+        return self._top_p
+
+    @property
+    def top_k(self) -> Parameter[int]:
+        """
+        Only sample from the top K options for each subsequent token.
+        Used to remove "long tail" low probability responses.
+        """
+        return self._top_k
+
+    @staticmethod
+    def from_env() -> GeminiLLMConfiguration:
+        api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap()
+        model = read_env_str(_env_var_prefix + "LLM_MODEL").unwrap()
+        config = GeminiLLMConfiguration(model=model, api_key=api_key)
+        return config
+
+    @staticmethod
+    def from_spec(spec: LLMConfigSpec) -> GeminiLLMConfiguration:
+        api_key = spec.provider.must_get_value("apiKey")
+        model = spec.provider.must_get_value("model")
+        config = GeminiLLMConfiguration(model=str(model), api_key=str(api_key))
+
+        if spec.parameters is not None:
+            value: Optional[Any] = spec.parameters.get("temperature", None)
+            if value is not None:
+                config.temperature.set(float(value))
+            value = spec.parameters.get("topP", None)
+            if value is not None:
+                config.top_p.set(float(value))
+            value = spec.parameters.get("topK", None)
+            if value is not None:
+                config.top_k.set(int(value))
+
+        return config
diff --git a/council/llm/llm_config_object.py b/council/llm/llm_config_object.py
index 94aff1c8..a4ff1668 100644
--- a/council/llm/llm_config_object.py
+++ b/council/llm/llm_config_object.py
@@ -13,6 +13,7 @@ class LLMProviders(str, Enum):
     OpenAI = "openAISpec"
     Azure = "azureSpec"
     Anthropic = "anthropicSpec"
+    Gemini = "googleGeminiSpec"
 
 
 class LLMProvider:
@@ -43,16 +44,21 @@ def from_dict(cls, values: Dict[str, Any]) -> LLMProvider:
         spec = values.get(LLMProviders.Anthropic)
         if spec is not None:
             return LLMProvider(name, description, spec, LLMProviders.Anthropic)
+        spec = values.get(LLMProviders.Gemini)
+        if spec is not None:
+            return LLMProvider(name, description, spec, LLMProviders.Gemini)
         raise ValueError("Unsupported model provider")
 
     def to_dict(self) -> Dict[str, Any]:
         result: Dict[str, Any] = {"name": self.name, "description": self.description}
         if self.is_of_kind(LLMProviders.OpenAI):
             result[LLMProviders.OpenAI] = self._specs
-        if self.is_of_kind(LLMProviders.Azure):
+        elif self.is_of_kind(LLMProviders.Azure):
             result[LLMProviders.Azure] = self._specs
-        if self.is_of_kind(LLMProviders.Anthropic):
+        elif self.is_of_kind(LLMProviders.Anthropic):
             result[LLMProviders.Anthropic] = self._specs
+        elif self.is_of_kind(LLMProviders.Gemini):
+            result[LLMProviders.Gemini] = self._specs
         return result
 
     def must_get_value(self, key: str) -> Any:
diff --git a/dev-requirements.txt b/dev-requirements.txt
index b6d9a6f6..5f48f0e0 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -11,12 +11,12 @@ types-beautifulsoup4~=4.12.0.7
 # Lint
 black==24.4.2
 mypy==1.10.0
-ruff==0.4.8
+ruff==0.4.10
 pylint==3.2.3
 isort==5.13.2
 
 # Test
 ipykernel==6.26.0
-nbconvert==7.11.0
-nbformat==5.9.2
+nbconvert==7.16.4
+nbformat==5.10.4
 pytest==7.4.4
diff --git a/requirements.txt b/requirements.txt
index c08365ff..69be0f92 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,6 +8,7 @@ tiktoken~=0.7.0
 
 # LLMs
 anthropic~=0.20.0
+google-generativeai==0.7.0
 
 # Skills
 ## Google
diff --git a/tests/data/gemini-llmodel.yaml b/tests/data/gemini-llmodel.yaml
new file mode 100644
index 00000000..eaaee4a0
--- /dev/null
+++ b/tests/data/gemini-llmodel.yaml
@@ -0,0 +1,17 @@
+kind: LLMConfig
+version: 0.1
+metadata:
+  name: an-gemini-deployed-model
+  labels:
+    provider: Google
+spec:
+  description: "Model used to do RST"
+  provider:
+    name: CML-Gemini
+    googleGeminiSpec:
+      model: gemini-pro
+      apiKey:
+        fromEnvVar: GEMINI_API_KEY
+  parameters:
+    temperature: 0.5
+    topK: 8
diff --git a/tests/integration/llm/test_gemini_llm.py b/tests/integration/llm/test_gemini_llm.py
new file mode 100644
index 00000000..745f0394
--- /dev/null
+++ b/tests/integration/llm/test_gemini_llm.py
@@ -0,0 +1,35 @@
+import unittest
+
+import dotenv
+from council import LLMContext
+from council.llm import LLMMessage, GeminiLLM
+from council.utils import OsEnviron
+
+
+class TestGeminiLLM(unittest.TestCase):
+    def test_completion(self):
+        messages = [LLMMessage.user_message("what is the capital of France?")]
+        dotenv.load_dotenv()
+        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.5-flash"):
+            instance = GeminiLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)
+
+            assert "Paris" in result.choices[0]
+
+    def test_message(self):
+        messages = [LLMMessage.user_message("what is the capital of France?")]
+        dotenv.load_dotenv()
+        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.0-pro"):
+            instance = GeminiLLM.from_env()
+            context = LLMContext.empty()
+            result = instance.post_chat_request(context, messages)
+
+            assert "Paris" in result.choices[0]
+
+        messages.append(LLMMessage.user_message("give a famous monument of that place"))
+        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.5-pro"):
+            instance = GeminiLLM.from_env()
+            result = instance.post_chat_request(context, messages)
+
+            assert "Eiffel" in result.choices[0]
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
index b2b4ca4c..14cadbe6 100644
--- a/tests/unit/__init__.py
+++ b/tests/unit/__init__.py
@@ -6,6 +6,7 @@ class LLModels:
     OpenAI: str = "openai-llmodel.yaml"
     Anthropic: str = "anthropic-llmodel.yaml"
     AzureWithFallback: str = "azure-with-fallback-llmodel.yaml"
+    Gemini: str = "gemini-llmodel.yaml"
 
 
 class LLMPrompts:
diff --git a/tests/unit/llm/test_gemini_llm_configuration.py b/tests/unit/llm/test_gemini_llm_configuration.py
new file mode 100644
index 00000000..1d4a22b1
--- /dev/null
+++ b/tests/unit/llm/test_gemini_llm_configuration.py
@@ -0,0 +1,22 @@
+import unittest
+from council.llm import GeminiLLMConfiguration
+from council.utils import OsEnviron, ParameterValueException
+
+
+class TestGeminiLLMConfiguration(unittest.TestCase):
+    def test_model_override(self):
+        with OsEnviron("GEMINI_API_KEY", "some-key"), OsEnviron("GEMINI_LLM_MODEL", "gemini-something"):
+            config = GeminiLLMConfiguration.from_env()
+            self.assertEqual("some-key", config.api_key.value)
+            self.assertEqual("gemini-something", config.model.value)
+
+    def test_default(self):
+        config = GeminiLLMConfiguration(model="gemini-something", api_key="some-key")
+        self.assertEqual(0.0, config.temperature.value)
+        self.assertTrue(config.top_p.is_none())
+
+    def test_invalid(self):
+        with self.assertRaises(ParameterValueException):
+            _ = GeminiLLMConfiguration(model="a-gemini-model", api_key="sk-key")
+        with self.assertRaises(ParameterValueException):
+            _ = GeminiLLMConfiguration(model="gemini-model", api_key="")
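
For reference, a minimal usage sketch of the new GeminiLLM class (not part of the patch), based on the integration test above. It assumes GEMINI_API_KEY and GEMINI_LLM_MODEL (e.g. gemini-1.5-flash) are set in the environment; post_chat_request is the public entry point inherited from LLMBase:

    import dotenv

    from council import LLMContext
    from council.llm import GeminiLLM, LLMMessage

    # Load GEMINI_API_KEY and GEMINI_LLM_MODEL from a local .env file, as the tests do
    dotenv.load_dotenv()
    llm = GeminiLLM.from_env()

    messages = [LLMMessage.user_message("what is the capital of France?")]
    result = llm.post_chat_request(LLMContext.empty(), messages)
    print(result.choices[0])

The tests/data/gemini-llmodel.yaml fixture in the patch shows the equivalent declarative configuration consumed by GeminiLLM.from_config.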