Merge pull request #29 from defog-ai/jp/gemini
Add gemini
Muhammad18557 authored Oct 30, 2024
2 parents f1cce2c + a6d0eeb commit 26af12f
Showing 6 changed files with 256 additions and 29 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
@@ -23,4 +23,5 @@ jobs:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}

58 changes: 54 additions & 4 deletions defog_utils/utils_llm.py
@@ -1,13 +1,17 @@
from dataclasses import dataclass
import os
import time
from typing import Optional, List, Dict
from dataclasses import dataclass
from typing import Dict, List, Optional

import google.generativeai as genai
from anthropic import Anthropic
from openai import OpenAI
from together import Together

client_anthropic = Anthropic()
client_openai = OpenAI()
client_together = Together()
genai.configure(api_key=os.environ["GEMINI_API_KEY"])


@dataclass
@@ -36,7 +40,7 @@ def chat_anthropic(
sys_msg = messages[0]["content"]
messages = messages[1:]
else:
sys_msg = None
sys_msg = ""
response = client_anthropic.messages.create(
system=sys_msg,
messages=messages,
@@ -97,7 +101,7 @@ def chat_openai(

def chat_together(
messages: List[Dict[str, str]],
model: str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
model: str = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_tokens: int = 4096,
temperature: float = 0.0,
stop: List[str] = [],
@@ -130,3 +134,49 @@ def chat_together(
response.usage.prompt_tokens,
response.usage.completion_tokens,
)


def chat_gemini(
messages: List[Dict[str, str]],
model: str = "gemini-1.5-pro",
max_tokens: int = 8192,
temperature: float = 0.0,
stop: List[str] = [],
json_mode: bool = False,
seed: int = 0,
) -> Optional[LLMResponse]:
t = time.time()
generation_config = {
"temperature": temperature,
"max_output_tokens": max_tokens,
"response_mime_type": "application/json" if json_mode else "text/plain",
"stop_sequences": stop,
# "seed": seed, # seed is not supported in the current version
}
if messages[0]["role"] == "system":
system_msg = messages[0]["content"]
messages = messages[1:]
else:
system_msg = None
final_msg = messages[-1]["content"]
messages = messages[:-1]
for msg in messages:
if msg["role"] != "user":
msg["role"] = "model"
client_gemini = genai.GenerativeModel(model, generation_config=generation_config, system_instruction=system_msg)
chat = client_gemini.start_chat(
history=messages,
)
response = chat.send_message(final_msg)
if len(response.candidates) == 0:
print("Empty response")
return None
if response.candidates[0].finish_reason.value != 1: # 1 is the finish reason for STOP
print("Max tokens reached")
return None
return LLMResponse(
content=response.candidates[0].content.parts[0].text,
time=round(time.time() - t, 3),
input_tokens=response.usage_metadata.prompt_token_count,
output_tokens=response.usage_metadata.candidates_token_count,
)
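
For reference, a minimal usage sketch of the new `chat_gemini` helper. It assumes `GEMINI_API_KEY` is already set in the environment (the module calls `genai.configure` at import time) and that the package is importable as `defog_utils`:

```python
import os

# The module configures the Gemini client at import time, so the key
# must already be in the environment before the import below.
assert "GEMINI_API_KEY" in os.environ

from defog_utils.utils_llm import chat_gemini  # import path assumed

messages = [
    {"role": "system", "content": "Reply in as few words as possible."},
    {"role": "user", "content": "Return a greeting in not more than 2 words"},
]

resp = chat_gemini(messages, model="gemini-1.5-flash", max_tokens=32)
# chat_gemini returns None on an empty response or a non-STOP finish
# reason, so callers should handle the Optional.
if resp is not None:
    print(resp.content, resp.input_tokens, resp.output_tokens, resp.time)
```

Note that `chat_gemini` rewrites non-user roles to `"model"` on the message dicts in place, so a caller reusing the same `messages` list across providers may want to pass a copy.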
28 changes: 28 additions & 0 deletions defog_utils/utils_multi_llm.py
@@ -0,0 +1,28 @@
import concurrent
from typing import Callable, Dict

from .utils_llm import LLMResponse, chat_anthropic, chat_gemini, chat_openai, chat_together

def map_model_to_chat_fn(model: str) -> Callable:
"""
Returns the appropriate chat function based on the model.
"""
if model.startswith("claude"):
return chat_anthropic
if model.startswith("gemini"):
return chat_gemini
if model.startswith("gpt"):
return chat_openai
if model.startswith("meta-llama") or model.startswith("mistralai") or model.startswith("Qwen"):
return chat_together
raise ValueError(f"Unknown model: {model}")

def chat(models, messages, max_tokens=8192, temperature=0.0, stop=[], json_mode=False, seed=0) -> Dict[str, LLMResponse]:
"""
Returns the response from the LLM API for each of the models passed in.
Output format is a dictionary keyed by model name.
"""
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
futures = {executor.submit(map_model_to_chat_fn(model), messages, model, max_tokens, temperature, stop, json_mode, seed): model for model in models}
responses = {futures[future]: future.result() for future in concurrent.futures.as_completed(futures)}
return responses
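
A similar sketch for the new multi-model `chat` helper, assuming the same import path and valid API keys for each provider used:

```python
from defog_utils.utils_multi_llm import chat  # import path assumed

messages = [{"role": "user", "content": "Return a greeting in not more than 2 words"}]

# One worker thread per model; each model name is routed to its provider's
# chat function by prefix (claude-* -> Anthropic, gemini-* -> Gemini,
# gpt-* -> OpenAI, meta-llama/mistralai/Qwen -> Together).
responses = chat(
    ["claude-3-haiku-20240307", "gemini-1.5-flash", "gpt-4o-mini"],
    messages,
    max_tokens=20,
)
for model, resp in responses.items():
    print(model, None if resp is None else resp.content)
```

Because the dict comprehension calls `future.result()`, an exception from any one provider propagates and fails the whole `chat` call rather than being reported per model.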
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
anthropic==0.37.1
google-generativeai==0.8.3
numpy
openai==1.52.2
psycopg2-binary==2.9.9
120 changes: 95 additions & 25 deletions tests/test_utils_llm.py
@@ -3,11 +3,13 @@
from ..defog_utils.utils_llm import (
LLMResponse,
chat_anthropic,
chat_gemini,
chat_openai,
chat_together,
)

messages = [
messages_no_sys = [{"role": "user", "content": "Return a greeting in not more than 2 words\n"}]
messages_sql = [
{
"role": "system",
"content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. Return only the SQL without ```.",
@@ -50,56 +52,112 @@
]

acceptable_sql = [
"SELECT COUNT(*) FROM orders",
"SELECT COUNT(order_id) FROM orders",
"select count(*) from orders",
"select count(order_id) from orders",
"select count(*) as total_orders from orders",
"select count(order_id) as total_orders from orders",
]

acceptable_sql_from_json = set(
[
"SELECT COUNT(order_id) as total_orders FROM orders;",
"SELECT COUNT(*) AS total_orders FROM orders;",
"SELECT COUNT(order_id) FROM orders;",
"SELECT COUNT(*) FROM orders;",
]
)
class TestChatClients(unittest.TestCase):

def check_sql(self, sql: str):
self.assertIn(sql.strip(";\n").lower(), acceptable_sql)

class TestChatClients(unittest.TestCase):
def test_chat_anthropic(self):
def test_chat_anthropic_no_sys(self):
response = chat_anthropic(
messages_no_sys,
model="claude-3-haiku-20240307",
max_tokens=10,
seed=0,
)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIsInstance(response.content, str)
self.assertEqual(response.input_tokens, 18)
self.assertLessEqual(response.output_tokens, 10)

def test_chat_gemini_no_sys(self):
response = chat_gemini(
messages_no_sys,
model="gemini-1.5-flash",
max_tokens=10,
seed=0,
)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIsInstance(response.content, str)
self.assertEqual(response.input_tokens, 12)
self.assertLessEqual(response.output_tokens, 10)

def test_chat_openai_no_sys(self):
response = chat_openai(
messages_no_sys,
model="gpt-4o-mini",
max_tokens=10,
seed=0,
)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIsInstance(response.content, str)
self.assertEqual(response.input_tokens, 18)
self.assertLessEqual(response.output_tokens, 10)

def test_chat_together_no_sys(self):
response = chat_together(
messages_no_sys,
model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
max_tokens=10,
seed=0,
)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIsInstance(response.content, str)
self.assertEqual(response.input_tokens, 46) # hidden sys prompt added I think
self.assertLessEqual(response.output_tokens, 10)

def test_chat_anthropic_sql(self):
response = chat_anthropic(
messages,
messages_sql,
model="claude-3-haiku-20240307",
max_tokens=100,
stop=[";"],
seed=0,
)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIn(response.content, acceptable_sql)
self.check_sql(response.content)
self.assertEqual(response.input_tokens, 90) # 90 input tokens
self.assertTrue(response.output_tokens < 10) # output tokens should be < 10
self.assertTrue(response.output_tokens < 15) # output tokens should be < 15

def test_chat_openai(self):
response = chat_openai(messages, model="gpt-4o-mini", stop=[";"], seed=0)
def test_chat_openai_sql(self):
response = chat_openai(messages_sql, model="gpt-4o-mini", stop=[";"], seed=0)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIn(response.content, acceptable_sql)
self.check_sql(response.content)
self.assertEqual(response.input_tokens, 83)
self.assertTrue(response.output_tokens < 10) # output tokens should be < 10

def test_chat_together(self):
def test_chat_together_sql(self):
response = chat_together(
messages,
messages_sql,
model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
stop=[";"],
seed=0,
)
print(response)
self.assertIsInstance(response, LLMResponse)
self.assertIn(response.content, acceptable_sql)
self.check_sql(response.content)
self.assertEqual(response.input_tokens, 108)
self.assertTrue(response.output_tokens < 10) # output tokens should be < 10

def test_chat_gemini_sql(self):
response = chat_gemini(messages_sql, model="gemini-1.5-flash", stop=[";"], seed=0)
print(response)
self.assertIsInstance(response, LLMResponse)
self.check_sql(response.content)
self.assertEqual(response.input_tokens, 86)
self.assertTrue(response.output_tokens < 10)

def test_chat_json_anthropic(self):
response = chat_anthropic(
messages_json,
@@ -111,7 +169,7 @@ def test_chat_json_anthropic(self):
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)
@@ -123,7 +181,7 @@ def test_chat_json_openai(self):
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)
@@ -139,7 +197,19 @@ def test_chat_json_together(self):
self.assertIsInstance(response, LLMResponse)
raw_output = response.content
resp_dict = json.loads(raw_output)
self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)

def test_chat_json_gemini(self):
response = chat_gemini(
messages_json, model="gemini-1.5-flash", seed=0, json_mode=True
)
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)
77 changes: 77 additions & 0 deletions tests/test_utils_multi_llm.py
@@ -0,0 +1,77 @@
import unittest
from ..defog_utils.utils_multi_llm import map_model_to_chat_fn, chat
from ..defog_utils.utils_llm import LLMResponse, chat_anthropic, chat_gemini, chat_openai, chat_together

messages_sql = [
{
"role": "system",
"content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. Return only the SQL without ```.",
},
{
"role": "user",
"content": f"""Question: What is the total number of orders?
Schema:
```sql
CREATE TABLE orders (
order_id int,
customer_id int,
employee_id int,
order_date date
);
```
""",
},
]

acceptable_sql = [
"select count(*) from orders",
"select count(order_id) from orders",
"select count(*) as total_orders from orders",
"select count(order_id) as total_orders from orders",
]

class TestChatClients(unittest.TestCase):
def check_sql(self, sql: str):
self.assertIn(sql.strip(";\n").lower(), acceptable_sql)

def test_map_model_to_chat_fn(self):
self.assertEqual(map_model_to_chat_fn("claude-3-5-sonnet-20241022"), chat_anthropic)
self.assertEqual(map_model_to_chat_fn("gemini-1.5-pro"), chat_gemini)
self.assertEqual(map_model_to_chat_fn("gpt-4o"), chat_openai)
self.assertEqual(map_model_to_chat_fn("mistralai/Mistral-7B-Instruct-v0.3"), chat_together)
self.assertEqual(map_model_to_chat_fn("meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"), chat_together)
self.assertEqual(map_model_to_chat_fn("Qwen/Qwen2.5-72B-Instruct-Turbo"), chat_together)
with self.assertRaises(ValueError):
map_model_to_chat_fn("unknown-model")

def test_simple_chat(self):
models = ["claude-3-haiku-20240307", "gemini-1.5-flash-002", "gpt-4o-mini", "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"]
messages = [{"role": "user", "content": "Return a greeting in not more than 2 words\n"}]
responses = chat(models, messages, max_tokens=20, temperature=0.0, stop=[";"], json_mode=False, seed=0)
self.assertIsInstance(responses, dict)
for model in models:
self.assertIn(model, responses)
response = responses[model]
print(model, response)
self.assertIsInstance(response, LLMResponse)
self.assertIsInstance(response.content, str)
self.assertIsInstance(response.time, float)
self.assertLess(response.input_tokens, 50) # higher as default system prompt is added in together's API when none provided
self.assertLess(response.output_tokens, 20)

def test_sql_chat(self):
models = ["claude-3-haiku-20240307", "gemini-1.5-flash-002", "gpt-4o-mini", "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"]
responses = chat(models, messages_sql, max_tokens=20, temperature=0.0, stop=[";"], json_mode=False, seed=0)
self.assertIsInstance(responses, dict)
for model in models:
self.assertIn(model, responses)
response = responses[model]
print(model, response)
self.assertIsInstance(response, LLMResponse)
self.assertIsInstance(response.content, str)
self.assertIsInstance(response.time, float)
self.assertLess(response.input_tokens, 110)
self.assertLess(response.output_tokens, 20)

if __name__ == "__main__":
unittest.main()
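
Since both test modules use package-relative imports (`from ..defog_utils import ...`), they presumably need to run with the repository's package context intact, e.g. `python -m pytest tests/` from the repository root, with all four provider API keys set as in the CI workflow above.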
