
Commit

Initial Gemini implementation
aflament committed Jun 22, 2024
1 parent a5937d7 commit b1ee415
Showing 11 changed files with 278 additions and 6 deletions.
2 changes: 1 addition & 1 deletion council/__init__.py
@@ -6,5 +6,5 @@
from .controllers import BasicController, ControllerBase, ExecutionUnit, LLMController
from .evaluators import BasicEvaluator, EvaluatorBase, LLMEvaluator
from .filters import BasicFilter, FilterBase
from .llm import AnthropicLLM, AzureLLM, OpenAILLM
from .llm import AnthropicLLM, AzureLLM, OpenAILLM, GeminiLLM
from .runners import DoWhile, If, Parallel, ParallelFor, RunnerGenerator, RunnerPredicate, Sequential, While
5 changes: 5 additions & 0 deletions council/llm/__init__.py
@@ -25,6 +25,9 @@
from .anthropic_llm_configuration import AnthropicLLMConfiguration
from .anthropic_llm import AnthropicLLM

from .gemini_llm_configuration import GeminiLLMConfiguration
from .gemini_llm import GeminiLLM


def get_default_llm(max_retries: Optional[int] = None) -> LLMBase:
    provider = read_env_str("COUNCIL_DEFAULT_LLM_PROVIDER", default=LLMProviders.OpenAI).unwrap()
@@ -37,6 +40,8 @@ def get_default_llm(max_retries: Optional[int] = None) -> LLMBase:
        llm = AzureLLM.from_env()
    elif provider == LLMProviders.Anthropic.lower():
        llm = AnthropicLLM.from_env()
    elif provider == LLMProviders.Gemini.lower():
        llm = GeminiLLM.from_env()

    if llm is None:
        raise ValueError(f"Provider {provider} not supported by council.")
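The new branch can be exercised end to end through environment variables. Below is a minimal sketch (not part of the commit): the provider value is assumed to be compared against LLMProviders.Gemini.lower() exactly as in the hunk above, and the API key is a placeholder.

# Hypothetical usage of the new Gemini branch in get_default_llm (illustrative only).
import os

from council.llm import LLMProviders, get_default_llm

os.environ["COUNCIL_DEFAULT_LLM_PROVIDER"] = LLMProviders.Gemini.lower()
os.environ["GEMINI_API_KEY"] = "your-api-key"        # placeholder
os.environ["GEMINI_LLM_MODEL"] = "gemini-1.5-flash"  # any name starting with "gemini-"

llm = get_default_llm()  # expected to return a GeminiLLM instance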
81 changes: 81 additions & 0 deletions council/llm/gemini_llm.py
@@ -0,0 +1,81 @@
from __future__ import annotations

from typing import Any, List, Sequence, Tuple

import google.generativeai as genai # type: ignore
from council.contexts import Consumption, LLMContext
from council.llm import (
    GeminiLLMConfiguration,
    LLMBase,
    LLMConfigObject,
    LLMMessage,
    LLMMessageRole,
    LLMProviders,
    LLMResult,
)
from google.ai.generativelanguage_v1 import HarmCategory # type: ignore
from google.generativeai.types import HarmBlockThreshold # type: ignore


class GeminiLLM(LLMBase[GeminiLLMConfiguration]):
    def __init__(self, config: GeminiLLMConfiguration) -> None:
        """
        Initialize a new instance.
        Args:
            config(GeminiLLMConfiguration): configuration for the instance
        """
        super().__init__(name=f"{self.__class__.__name__}", configuration=config)
        genai.configure(api_key=config.api_key.value)
        self._model = genai.GenerativeModel(
            config.model_name(),
            safety_settings={
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
            },
        )

    def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
        history, last = self._to_chat_history(messages=messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(last)
        return LLMResult(choices=[response.text], consumptions=self.to_consumptions())

    def to_consumptions(self) -> Sequence[Consumption]:
        model = self._configuration.model_name()
        return [
            Consumption(1, "call", f"{model}"),
        ]

    @staticmethod
    def from_env() -> GeminiLLM:
        """
        Helper function that creates a new instance by getting the configuration from environment variables.
        Returns:
            GeminiLLM
        """

        return GeminiLLM(GeminiLLMConfiguration.from_env())

    @staticmethod
    def from_config(config_object: LLMConfigObject) -> GeminiLLM:
        provider = config_object.spec.provider
        if not provider.is_of_kind(LLMProviders.Gemini):
            raise ValueError(f"Invalid LLM provider, actual {provider}, expected {LLMProviders.Gemini}")

        config = GeminiLLMConfiguration.from_spec(config_object.spec)
        return GeminiLLM(config=config)

    @staticmethod
    def _to_chat_history(messages: Sequence[LLMMessage]) -> Tuple[List[Any], Any]:
        history = []
        for message in messages[:-1]:
            if message.is_of_role(LLMMessageRole.System):
                history.append({"role": "user", "parts": [{"text": f"System Prompt: {message.content}"}]})
                history.append({"role": "model", "parts": [{"text": "Understood"}]})
            elif message.is_of_role(LLMMessageRole.User):
                history.append({"role": "user", "parts": [{"text": message.content}]})
            elif message.is_of_role(LLMMessageRole.Assistant):
                history.append({"role": "model", "parts": [{"text": message.content}]})
        last = messages[-1].content
        return history, last
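For orientation, here is a minimal usage sketch of the new class (not part of the commit). The model name and API key are placeholders, and LLMMessage.system_message is assumed to exist as the system-role counterpart of user_message; _to_chat_history folds such a message into the history as a "System Prompt: ..." user turn followed by a stub "Understood" model reply.

# Illustrative GeminiLLM usage; model name and API key are placeholders.
from council import LLMContext
from council.llm import GeminiLLM, GeminiLLMConfiguration, LLMMessage

config = GeminiLLMConfiguration(model="gemini-1.5-flash", api_key="your-api-key")
llm = GeminiLLM(config)

messages = [
    LLMMessage.system_message("Answer concisely."),  # folded into history by _to_chat_history
    LLMMessage.user_message("What is the capital of France?"),
]
result = llm.post_chat_request(LLMContext.empty(), messages)
print(result.choices[0])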
104 changes: 104 additions & 0 deletions council/llm/gemini_llm_configuration.py
@@ -0,0 +1,104 @@
from __future__ import annotations

from typing import Any, Final, Optional

from council.utils import Parameter, greater_than_validator, not_empty_validator, prefix_validator, read_env_str

from . import LLMConfigSpec, LLMConfigurationBase

_env_var_prefix: Final[str] = "GEMINI_"


def _tv(x: float) -> None:
"""
Temperature and Top_p Validators
Sampling temperature to use, between 0. and 1.
"""
if x < 0.0 or x > 1.0:
raise ValueError("must be in the range [0.0..1.0]")


class GeminiLLMConfiguration(LLMConfigurationBase):
    def __init__(self, model: str, api_key: str) -> None:
        """
        Initialize a new instance
        Args:
            api_key (str): the api key
            model (str): the model name, which must start with "gemini-"
        """
        super().__init__()
        self._model = Parameter.string(name="model", required=True, value=model, validator=prefix_validator("gemini-"))
        self._api_key = Parameter.string(name="api_key", required=True, value=api_key, validator=not_empty_validator)
        self._temperature = Parameter.float(name="temperature", required=False, default=0.0, validator=_tv)
        self._top_p = Parameter.float(name="top_p", required=False, validator=_tv)
        self._top_k = Parameter.int(name="top_k", required=False, validator=greater_than_validator(0))

    def model_name(self) -> str:
        return self._model.unwrap()

    @property
    def model(self) -> Parameter[str]:
        """
        Gemini model
        """
        return self._model

    @property
    def api_key(self) -> Parameter[str]:
        """
        Gemini API Key
        """
        return self._api_key

    @property
    def temperature(self) -> Parameter[float]:
        """
        Amount of randomness injected into the response.
        Ranges from 0 to 1.
        Use temp closer to 0 for analytical / multiple choice, and closer to 1 for creative and generative tasks.
        """
        return self._temperature

    @property
    def top_p(self) -> Parameter[float]:
        """
        Use nucleus sampling.
        In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in
        decreasing probability order and cut it off once it reaches a particular probability specified by top_p.
        """
        return self._top_p

    @property
    def top_k(self) -> Parameter[int]:
        """
        Only sample from the top K options for each subsequent token.
        Used to remove "long tail" low probability responses.
        """
        return self._top_k

    @staticmethod
    def from_env() -> GeminiLLMConfiguration:
        api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap()
        model = read_env_str(_env_var_prefix + "LLM_MODEL").unwrap()
        config = GeminiLLMConfiguration(model=model, api_key=api_key)
        return config

    @staticmethod
    def from_spec(spec: LLMConfigSpec) -> GeminiLLMConfiguration:
        api_key = spec.provider.must_get_value("apiKey")
        model = spec.provider.must_get_value("model")
        config = GeminiLLMConfiguration(model=str(model), api_key=str(api_key))

        if spec.parameters is not None:
            value: Optional[Any] = spec.parameters.get("temperature", None)
            if value is not None:
                config.temperature.set(float(value))
            value = spec.parameters.get("topP", None)
            if value is not None:
                config.top_p.set(float(value))
            value = spec.parameters.get("topK", None)
            if value is not None:
                config.top_k.set(int(value))

        return config
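To make the parameter handling concrete, here is an illustrative sketch of the validation behaviour defined above (not part of the commit), consistent with the unit tests added further down.

# Illustrative GeminiLLMConfiguration validation (placeholder values).
from council.llm import GeminiLLMConfiguration

config = GeminiLLMConfiguration(model="gemini-1.5-pro", api_key="your-api-key")
config.temperature.set(0.2)  # accepted: within [0.0, 1.0], checked by _tv
config.top_k.set(8)          # accepted: checked by greater_than_validator(0)

# Expected to raise ParameterValueException (see the unit tests below):
#   GeminiLLMConfiguration(model="pro-1.5", api_key="your-api-key")  # model must start with "gemini-"
#   GeminiLLMConfiguration(model="gemini-1.5-pro", api_key="")       # api_key must not be empty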
10 changes: 8 additions & 2 deletions council/llm/llm_config_object.py
@@ -13,6 +13,7 @@ class LLMProviders(str, Enum):
    OpenAI = "openAISpec"
    Azure = "azureSpec"
    Anthropic = "anthropicSpec"
    Gemini = "googleGeminiSpec"


class LLMProvider:
@@ -43,16 +44,21 @@ def from_dict(cls, values: Dict[str, Any]) -> LLMProvider:
        spec = values.get(LLMProviders.Anthropic)
        if spec is not None:
            return LLMProvider(name, description, spec, LLMProviders.Anthropic)
        spec = values.get(LLMProviders.Gemini)
        if spec is not None:
            return LLMProvider(name, description, spec, LLMProviders.Gemini)
        raise ValueError("Unsupported model provider")

    def to_dict(self) -> Dict[str, Any]:
        result: Dict[str, Any] = {"name": self.name, "description": self.description}
        if self.is_of_kind(LLMProviders.OpenAI):
            result[LLMProviders.OpenAI] = self._specs
        if self.is_of_kind(LLMProviders.Azure):
        elif self.is_of_kind(LLMProviders.Azure):
            result[LLMProviders.Azure] = self._specs
        if self.is_of_kind(LLMProviders.Anthropic):
        elif self.is_of_kind(LLMProviders.Anthropic):
            result[LLMProviders.Anthropic] = self._specs
        elif self.is_of_kind(LLMProviders.Gemini):
            result[LLMProviders.Gemini] = self._specs
        return result

    def must_get_value(self, key: str) -> Any:
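As a small illustration of the new provider kind (not part of the commit), the round trip below uses placeholder values and the positional constructor call visible in from_dict above.

# Illustrative LLMProvider round trip for the Gemini kind (placeholder values).
from council.llm.llm_config_object import LLMProvider, LLMProviders

provider = LLMProvider("CML-Gemini", "example", {"model": "gemini-pro", "apiKey": "your-api-key"}, LLMProviders.Gemini)
assert provider.is_of_kind(LLMProviders.Gemini)
assert LLMProviders.Gemini in provider.to_dict()  # serialized under the "googleGeminiSpec" key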
6 changes: 3 additions & 3 deletions dev-requirements.txt
@@ -11,12 +11,12 @@ types-beautifulsoup4~=4.12.0.7
# Lint
black==24.4.2
mypy==1.10.0
ruff==0.4.8
ruff==0.4.10
pylint==3.2.3
isort==5.13.2

# Test
ipykernel==6.26.0
nbconvert==7.11.0
nbformat==5.9.2
nbconvert==7.16.4
nbformat==5.10.4
pytest==7.4.4
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,6 +8,7 @@ tiktoken~=0.7.0

# LLMs
anthropic~=0.20.0
google-generativeai==0.7.0

# Skills
## Google
17 changes: 17 additions & 0 deletions tests/data/gemini-llmodel.yaml
@@ -0,0 +1,17 @@
kind: LLMConfig
version: 0.1
metadata:
  name: an-gemini-deployed-model
  labels:
    provider: Google
spec:
  description: "Model used to do RST"
  provider:
    name: CML-Gemini
    googleGeminiSpec:
      model: gemini-pro
      apiKey:
        fromEnvVar: GEMINI_API_KEY
  parameters:
    temperature: 0.5
    topK: 8
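A hedged sketch of consuming this fixture follows; it assumes LLMConfigObject exposes a from_yaml loader (not shown in this commit) and that GEMINI_API_KEY is set so the fromEnvVar reference resolves.

# Hypothetical loading of this fixture; LLMConfigObject.from_yaml is assumed to exist.
from council.llm import GeminiLLM, LLMConfigObject

config_object = LLMConfigObject.from_yaml("tests/data/gemini-llmodel.yaml")
llm = GeminiLLM.from_config(config_object)  # raises ValueError for any non-googleGeminiSpec provider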
35 changes: 35 additions & 0 deletions tests/integration/llm/test_gemini_llm.py
@@ -0,0 +1,35 @@
import unittest

import dotenv
from council import LLMContext
from council.llm import LLMMessage, GeminiLLM
from council.utils import OsEnviron


class TestGeminiLLM(unittest.TestCase):
    def test_completion(self):
        messages = [LLMMessage.user_message("what is the capital of France?")]
        dotenv.load_dotenv()
        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.5-flash"):
            instance = GeminiLLM.from_env()
            context = LLMContext.empty()
            result = instance.post_chat_request(context, messages)

            assert "Paris" in result.choices[0]

    def test_message(self):
        messages = [LLMMessage.user_message("what is the capital of France?")]
        dotenv.load_dotenv()
        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.0-pro"):
            instance = GeminiLLM.from_env()
            context = LLMContext.empty()
            result = instance.post_chat_request(context, messages)

            assert "Paris" in result.choices[0]

        messages.append(LLMMessage.user_message("give a famous monument of that place"))
        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.5-pro"):
            instance = GeminiLLM.from_env()
            result = instance.post_chat_request(context, messages)

            assert "Eiffel" in result.choices[0]
1 change: 1 addition & 0 deletions tests/unit/__init__.py
@@ -6,6 +6,7 @@ class LLModels:
    OpenAI: str = "openai-llmodel.yaml"
    Anthropic: str = "anthropic-llmodel.yaml"
    AzureWithFallback: str = "azure-with-fallback-llmodel.yaml"
    Gemini: str = "gemini-llmodel.yaml"


class LLMPrompts:
22 changes: 22 additions & 0 deletions tests/unit/llm/test_gemini_llm_configuration.py
@@ -0,0 +1,22 @@
import unittest
from council.llm import GeminiLLMConfiguration
from council.utils import OsEnviron, ParameterValueException


class TestGeminiLLMConfiguration(unittest.TestCase):
    def test_model_override(self):
        with OsEnviron("GEMINI_API_KEY", "some-key"), OsEnviron("GEMINI_LLM_MODEL", "gemini-something"):
            config = GeminiLLMConfiguration.from_env()
            self.assertEqual("some-key", config.api_key.value)
            self.assertEqual("gemini-something", config.model.value)

    def test_default(self):
        config = GeminiLLMConfiguration(model="gemini-something", api_key="some-key")
        self.assertEqual(0.0, config.temperature.value)
        self.assertTrue(config.top_p.is_none())

    def test_invalid(self):
        with self.assertRaises(ParameterValueException):
            _ = GeminiLLMConfiguration(model="a-gemini-model", api_key="sk-key")
        with self.assertRaises(ParameterValueException):
            _ = GeminiLLMConfiguration(model="gemini-model", api_key="")
