Initial Gemini implementation (#149)
* Initial Gemini implementation
* Fix
* Update requirements.txt
* Add doc
* Update requirements.txt
Showing 14 changed files with 312 additions and 11 deletions.
@@ -0,0 +1,81 @@ (new file: the GeminiLLM implementation)

```python
from __future__ import annotations

from typing import Any, List, Sequence, Tuple

import google.generativeai as genai  # type: ignore
from council.contexts import Consumption, LLMContext
from council.llm import (
    GeminiLLMConfiguration,
    LLMBase,
    LLMConfigObject,
    LLMMessage,
    LLMMessageRole,
    LLMProviders,
    LLMResult,
)
from google.ai.generativelanguage_v1 import HarmCategory  # type: ignore
from google.generativeai.types import HarmBlockThreshold  # type: ignore


class GeminiLLM(LLMBase[GeminiLLMConfiguration]):
    def __init__(self, config: GeminiLLMConfiguration) -> None:
        """
        Initialize a new instance.

        Args:
            config (GeminiLLMConfiguration): configuration for the instance
        """
        super().__init__(name=f"{self.__class__.__name__}", configuration=config)
        genai.configure(api_key=config.api_key.value)
        self._model = genai.GenerativeModel(
            config.model_name(),
            safety_settings={
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
            },
        )

    def _post_chat_request(self, context: LLMContext, messages: Sequence[LLMMessage], **kwargs: Any) -> LLMResult:
        history, last = self._to_chat_history(messages=messages)
        chat = self._model.start_chat(history=history)
        response = chat.send_message(last)
        return LLMResult(choices=[response.text], consumptions=self.to_consumptions())

    def to_consumptions(self) -> Sequence[Consumption]:
        model = self._configuration.model_name()
        return [
            Consumption(1, "call", f"{model}"),
        ]

    @staticmethod
    def from_env() -> GeminiLLM:
        """
        Helper function that creates a new instance from environment variables.

        Returns:
            GeminiLLM
        """
        return GeminiLLM(GeminiLLMConfiguration.from_env())

    @staticmethod
    def from_config(config_object: LLMConfigObject) -> GeminiLLM:
        provider = config_object.spec.provider
        if not provider.is_of_kind(LLMProviders.Gemini):
            raise ValueError(f"Invalid LLM provider, actual {provider}, expected {LLMProviders.Gemini}")

        config = GeminiLLMConfiguration.from_spec(config_object.spec)
        return GeminiLLM(config=config)

    @staticmethod
    def _to_chat_history(messages: Sequence[LLMMessage]) -> Tuple[List[Any], Any]:
        # Gemini chats only know "user" and "model" roles, so a system message is
        # folded into the history as a user turn followed by a canned model reply.
        history = []
        for message in messages[:-1]:
            if message.is_of_role(LLMMessageRole.System):
                history.append({"role": "user", "parts": [{"text": f"System Prompt: {message.content}"}]})
                history.append({"role": "model", "parts": [{"text": "Understood"}]})
            elif message.is_of_role(LLMMessageRole.User):
                history.append({"role": "user", "parts": [{"text": message.content}]})
            elif message.is_of_role(LLMMessageRole.Assistant):
                history.append({"role": "model", "parts": [{"text": message.content}]})
        # The last message is sent separately via send_message.
        last = messages[-1].content
        return history, last
```
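A minimal usage sketch for the class above, assuming the `GEMINI_API_KEY` and `GEMINI_LLM_MODEL` environment variables are set (these are the names read by `GeminiLLMConfiguration.from_env`, shown further below):

```python
from council import LLMContext
from council.llm import GeminiLLM, LLMMessage

# Assumes GEMINI_API_KEY and GEMINI_LLM_MODEL (e.g. "gemini-1.5-flash") are set.
llm = GeminiLLM.from_env()

# The system message is folded into the Gemini chat history by _to_chat_history.
messages = [
    LLMMessage.system_message("You are a concise assistant."),
    LLMMessage.user_message("What is the capital of France?"),
]
result = llm.post_chat_request(LLMContext.empty(), messages)
print(result.choices[0])
```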
@@ -0,0 +1,104 @@ (new file: the GeminiLLMConfiguration)

```python
from __future__ import annotations

from typing import Any, Final, Optional

from council.utils import Parameter, greater_than_validator, not_empty_validator, prefix_validator, read_env_str

from . import LLMConfigSpec, LLMConfigurationBase

_env_var_prefix: Final[str] = "GEMINI_"


def _tv(x: float) -> None:
    """
    Validator for temperature and top_p: both must be in the range [0.0, 1.0].
    """
    if x < 0.0 or x > 1.0:
        raise ValueError("must be in the range [0.0..1.0]")


class GeminiLLMConfiguration(LLMConfigurationBase):
    def __init__(self, model: str, api_key: str) -> None:
        """
        Initialize a new instance.

        Args:
            model (str): the model name, which must start with "gemini-"
            api_key (str): the Gemini API key
        """
        super().__init__()
        self._model = Parameter.string(name="model", required=True, value=model, validator=prefix_validator("gemini-"))
        self._api_key = Parameter.string(name="api_key", required=True, value=api_key, validator=not_empty_validator)
        self._temperature = Parameter.float(name="temperature", required=False, default=0.0, validator=_tv)
        self._top_p = Parameter.float(name="top_p", required=False, validator=_tv)
        self._top_k = Parameter.int(name="top_k", required=False, validator=greater_than_validator(0))

    def model_name(self) -> str:
        return self._model.unwrap()

    @property
    def model(self) -> Parameter[str]:
        """
        Gemini model.
        """
        return self._model

    @property
    def api_key(self) -> Parameter[str]:
        """
        Gemini API key.
        """
        return self._api_key

    @property
    def temperature(self) -> Parameter[float]:
        """
        Amount of randomness injected into the response.
        Ranges from 0 to 1.
        Use a temperature closer to 0 for analytical / multiple-choice tasks, and closer to 1
        for creative and generative tasks.
        """
        return self._temperature

    @property
    def top_p(self) -> Parameter[float]:
        """
        Use nucleus sampling.
        In nucleus sampling, we compute the cumulative distribution over all the options for each
        subsequent token in decreasing probability order and cut it off once it reaches the
        probability specified by top_p.
        """
        return self._top_p

    @property
    def top_k(self) -> Parameter[int]:
        """
        Only sample from the top K options for each subsequent token.
        Used to remove "long tail" low-probability responses.
        """
        return self._top_k

    @staticmethod
    def from_env() -> GeminiLLMConfiguration:
        api_key = read_env_str(_env_var_prefix + "API_KEY").unwrap()
        model = read_env_str(_env_var_prefix + "LLM_MODEL").unwrap()
        config = GeminiLLMConfiguration(model=model, api_key=api_key)
        return config

    @staticmethod
    def from_spec(spec: LLMConfigSpec) -> GeminiLLMConfiguration:
        api_key = spec.provider.must_get_value("apiKey")
        model = spec.provider.must_get_value("model")
        config = GeminiLLMConfiguration(model=str(model), api_key=str(api_key))

        if spec.parameters is not None:
            value: Optional[Any] = spec.parameters.get("temperature", None)
            if value is not None:
                config.temperature.set(float(value))
            value = spec.parameters.get("topP", None)
            if value is not None:
                config.top_p.set(float(value))
            value = spec.parameters.get("topK", None)
            if value is not None:
                config.top_k.set(int(value))

        return config
```
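A quick sketch of how the configuration's validators behave; the model name and key below are placeholder values:

```python
from council.llm import GeminiLLMConfiguration

# Placeholder values: the model must start with "gemini-" and the key must be non-empty.
config = GeminiLLMConfiguration(model="gemini-pro", api_key="some-key")

# Optional sampling parameters; temperature and top_p are validated to [0.0, 1.0],
# and top_k must be strictly greater than 0.
config.temperature.set(0.2)
config.top_k.set(8)

# Out-of-range values are rejected by the validators, e.g. a temperature of 1.5
# fails with "must be in the range [0.0..1.0]".
```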
@@ -0,0 +1,11 @@ (new doc page for GeminiLLM)

# GeminiLLM

```{eval-rst}
.. autoclasstree:: council.llm.GeminiLLM
   :full:
   :namespace: council
```

```{eval-rst}
.. autoclass:: council.llm.GeminiLLM
```
@@ -0,0 +1,5 @@ | ||
# AnthropicLLMConfiguration | ||
|
||
```{eval-rst} | ||
.. autoclass:: council.llm.GeminiLLMConfiguration | ||
``` |
@@ -0,0 +1,17 @@ (new file: a sample LLMConfig for Gemini)

```yaml
kind: LLMConfig
version: 0.1
metadata:
  name: an-gemini-deployed-model
  labels:
    provider: Google
spec:
  description: "Model used to do RST"
  provider:
    name: CML-Gemini
    googleGeminiSpec:
      model: gemini-pro
      apiKey:
        fromEnvVar: GEMINI_API_KEY
  parameters:
    temperature: 0.5
    topK: 8
```
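For reference, a sketch of consuming this file through the config-object path; the filename is illustrative, and `LLMConfigObject.from_yaml` is assumed as the loader for `LLMConfig` files:

```python
from council.llm import GeminiLLM, LLMConfigObject

# Illustrative filename; point this at the YAML definition above.
config_object = LLMConfigObject.from_yaml("llm-config-gemini.yaml")

# from_config verifies the provider kind is Gemini, then builds the
# configuration via GeminiLLMConfiguration.from_spec.
llm = GeminiLLM.from_config(config_object)
```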
@@ -0,0 +1,35 @@ (new file: integration tests for GeminiLLM)

```python
import unittest

import dotenv
from council import LLMContext
from council.llm import LLMMessage, GeminiLLM
from council.utils import OsEnviron


class TestGeminiLLM(unittest.TestCase):
    def test_completion(self):
        messages = [LLMMessage.user_message("what is the capital of France?")]
        dotenv.load_dotenv()
        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.5-flash"):
            instance = GeminiLLM.from_env()
            context = LLMContext.empty()
            result = instance.post_chat_request(context, messages)

            assert "Paris" in result.choices[0]

    def test_message(self):
        messages = [LLMMessage.user_message("what is the capital of France?")]
        dotenv.load_dotenv()
        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.0-pro"):
            instance = GeminiLLM.from_env()
            context = LLMContext.empty()
            result = instance.post_chat_request(context, messages)

            assert "Paris" in result.choices[0]

        messages.append(LLMMessage.user_message("give a famous monument of that place"))
        with OsEnviron("GEMINI_LLM_MODEL", "gemini-1.5-pro"):
            instance = GeminiLLM.from_env()
            result = instance.post_chat_request(context, messages)

            assert "Eiffel" in result.choices[0]
```
@@ -0,0 +1,22 @@ (new file: unit tests for GeminiLLMConfiguration)

```python
import unittest

from council.llm import GeminiLLMConfiguration
from council.utils import OsEnviron, ParameterValueException


class TestGeminiLLMConfiguration(unittest.TestCase):
    def test_model_override(self):
        with OsEnviron("GEMINI_API_KEY", "some-key"), OsEnviron("GEMINI_LLM_MODEL", "gemini-something"):
            config = GeminiLLMConfiguration.from_env()
            self.assertEqual("some-key", config.api_key.value)
            self.assertEqual("gemini-something", config.model.value)

    def test_default(self):
        config = GeminiLLMConfiguration(model="gemini-something", api_key="some-key")
        self.assertEqual(0.0, config.temperature.value)
        self.assertTrue(config.top_p.is_none())

    def test_invalid(self):
        with self.assertRaises(ParameterValueException):
            _ = GeminiLLMConfiguration(model="a-gemini-model", api_key="sk-key")
        with self.assertRaises(ParameterValueException):
            _ = GeminiLLMConfiguration(model="gemini-model", api_key="")
```