diff --git a/defog_utils/utils_llm.py b/defog_utils/utils_llm.py
index 1bda4e1..63769b7 100644
--- a/defog_utils/utils_llm.py
+++ b/defog_utils/utils_llm.py
@@ -1,6 +1,9 @@
-from dataclasses import dataclass
+import os
 import time
-from typing import Optional, List, Dict
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import google.generativeai as genai
 from anthropic import Anthropic
 from openai import OpenAI
 from together import Together
@@ -8,6 +11,10 @@
 client_anthropic = Anthropic()
 client_openai = OpenAI()
 client_together = Together()
+# Use .get() so importing this module does not raise KeyError for users who
+# never call chat_gemini; the genai client reports a clear "missing API key"
+# error at request time instead.
+genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
 
 
 @dataclass
@@ -130,3 +137,57 @@ def chat_together(
         response.usage.prompt_tokens,
         response.usage.completion_tokens,
     )
+
+
+def chat_gemini(
+    messages: List[Dict[str, str]],
+    model: str = "gemini-1.5-pro",
+    max_tokens: int = 8192,
+    temperature: float = 0.0,
+    stop: List[str] = [],
+    json_mode: bool = False,
+    seed: int = 0,
+) -> Optional[LLMResponse]:
+    """Send messages to a Gemini model and return an LLMResponse.
+
+    A leading system message, if present, becomes the model's system
+    instruction; the last message is sent as the prompt and anything in
+    between is replayed as chat history. Returns None when the model
+    yields no candidates. seed is accepted for signature parity with the
+    other chat_* helpers but is not supported by the API yet.
+    """
+    t = time.time()
+    generation_config = {
+        "temperature": temperature,
+        "max_output_tokens": max_tokens,
+        "response_mime_type": "application/json" if json_mode else "text/plain",
+        "stop_sequences": stop,
+        # "seed": seed, # seed is not supported in the current version
+    }
+    # Default to None so a missing system message doesn't raise NameError.
+    system_msg = None
+    if messages[0]["role"] == "system":
+        system_msg = messages[0]["content"]
+        messages = messages[1:]
+    final_msg = messages[-1]["content"]
+    messages = messages[:-1]
+    # Gemini chat history only accepts "user" and "model" roles.
+    for msg in messages:
+        if msg["role"] != "user":
+            msg["role"] = "model"
+    client_gemini = genai.GenerativeModel(
+        model, generation_config=generation_config, system_instruction=system_msg
+    )
+    chat = client_gemini.start_chat(
+        history=messages,
+    )
+    response = chat.send_message(final_msg)
+    if len(response.candidates) == 0:
+        print("Empty response")
+        return None
+    return LLMResponse(
+        content=response.candidates[0].content.parts[0].text,
+        time=round(time.time() - t, 3),
+        input_tokens=response.usage_metadata.prompt_token_count,
+        output_tokens=response.usage_metadata.candidates_token_count,
+    )
diff --git a/requirements.txt b/requirements.txt
index ee803eb..97d8cd5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 anthropic==0.37.1
+google-generativeai==0.8.3
 numpy
 openai==1.52.2
 psycopg2-binary==2.9.9
diff --git a/tests/test_utils_llm.py b/tests/test_utils_llm.py
index ac0c302..eb019d0 100644
--- a/tests/test_utils_llm.py
+++ b/tests/test_utils_llm.py
@@ -3,6 +3,7 @@
 from ..defog_utils.utils_llm import (
     LLMResponse,
     chat_anthropic,
+    chat_gemini,
     chat_openai,
     chat_together,
 )
@@ -50,21 +51,17 @@
 ]
 
 acceptable_sql = [
-    "SELECT COUNT(*) FROM orders",
-    "SELECT COUNT(order_id) FROM orders",
+    "select count(*) from orders",
+    "select count(order_id) from orders",
+    "select count(*) as total_orders from orders",
+    "select count(order_id) as total_orders from orders",
 ]
 
-acceptable_sql_from_json = set(
-    [
-        "SELECT COUNT(order_id) as total_orders FROM orders;",
-        "SELECT COUNT(*) AS total_orders FROM orders;",
-        "SELECT COUNT(order_id) FROM orders;",
-        "SELECT COUNT(*) FROM orders;",
-    ]
-)
 
 
+class TestChatClients(unittest.TestCase):
+    def check_sql(self, sql: str):
+        self.assertIn(sql.strip(";\n").lower(), acceptable_sql)
-class TestChatClients(unittest.TestCase):
     def test_chat_anthropic(self):
         response = chat_anthropic(
             messages,
@@ -75,7 +72,7 @@ def test_chat_anthropic(self):
         )
         print(response)
         self.assertIsInstance(response, LLMResponse)
-        self.assertIn(response.content, acceptable_sql)
+        self.check_sql(response.content)
         self.assertEqual(response.input_tokens, 90)  # 90 input tokens
         self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10
 
@@ -83,7 +80,7 @@ def test_chat_openai(self):
         response = chat_openai(messages, model="gpt-4o-mini", stop=[";"], seed=0)
         print(response)
         self.assertIsInstance(response, LLMResponse)
-        self.assertIn(response.content, acceptable_sql)
+        self.check_sql(response.content)
         self.assertEqual(response.input_tokens, 83)
         self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10
 
@@ -96,10 +93,18 @@ def test_chat_together(self):
         )
         print(response)
         self.assertIsInstance(response, LLMResponse)
-        self.assertIn(response.content, acceptable_sql)
+        self.check_sql(response.content)
         self.assertEqual(response.input_tokens, 108)
         self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10
 
+    def test_chat_gemini(self):
+        response = chat_gemini(messages, model="gemini-1.5-flash", stop=[";"], seed=0)
+        print(response)
+        self.assertIsInstance(response, LLMResponse)
+        self.check_sql(response.content)
+        self.assertEqual(response.input_tokens, 86)
+        self.assertTrue(response.output_tokens < 10)
+
     def test_chat_json_anthropic(self):
         response = chat_anthropic(
             messages_json,
@@ -111,7 +116,7 @@
         print(response)
         self.assertIsInstance(response, LLMResponse)
         resp_dict = json.loads(response.content)
-        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
+        self.check_sql(resp_dict["sql"])
         self.assertIsInstance(resp_dict["reasoning"], str)
         self.assertIsInstance(response.input_tokens, int)
         self.assertIsInstance(response.output_tokens, int)
@@ -123,7 +128,7 @@ def test_chat_json_openai(self):
         print(response)
         self.assertIsInstance(response, LLMResponse)
         resp_dict = json.loads(response.content)
-        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
+        self.check_sql(resp_dict["sql"])
         self.assertIsInstance(resp_dict["reasoning"], str)
         self.assertIsInstance(response.input_tokens, int)
         self.assertIsInstance(response.output_tokens, int)
@@ -139,7 +144,19 @@ def test_chat_json_together(self):
         self.assertIsInstance(response, LLMResponse)
         raw_output = response.content
         resp_dict = json.loads(raw_output)
-        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
+        self.check_sql(resp_dict["sql"])
+        self.assertIsInstance(resp_dict["reasoning"], str)
+        self.assertIsInstance(response.input_tokens, int)
+        self.assertIsInstance(response.output_tokens, int)
+
+    def test_chat_json_gemini(self):
+        response = chat_gemini(
+            messages_json, model="gemini-1.5-flash", seed=0, json_mode=True
+        )
+        print(response)
+        self.assertIsInstance(response, LLMResponse)
+        resp_dict = json.loads(response.content)
+        self.check_sql(resp_dict["sql"])
         self.assertIsInstance(resp_dict["reasoning"], str)
         self.assertIsInstance(response.input_tokens, int)
         self.assertIsInstance(response.output_tokens, int)