From a6d0eeb74620f05dd4a80526cbc9878518f83697 Mon Sep 17 00:00:00 2001
From: jp
Date: Wed, 30 Oct 2024 18:32:10 +0800
Subject: [PATCH] Add gemini

Add multi-llm chat interface
Add sys_msg for anthropic and gemini when not present
---
 .github/workflows/main.yml     |   1 +
 defog_utils/utils_llm.py       |  58 ++++++++++++++--
 defog_utils/utils_multi_llm.py |  28 ++++++++
 requirements.txt               |   1 +
 tests/test_utils_llm.py        | 120 ++++++++++++++++++++++++++-------
 tests/test_utils_multi_llm.py  |  77 +++++++++++++++++++++
 6 files changed, 256 insertions(+), 29 deletions(-)
 create mode 100644 defog_utils/utils_multi_llm.py
 create mode 100644 tests/test_utils_multi_llm.py

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 715b662..b08495c 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -23,4 +23,5 @@ jobs:
       ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
       OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
       TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
+      GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}

diff --git a/defog_utils/utils_llm.py b/defog_utils/utils_llm.py
index 1bda4e1..87e0380 100644
--- a/defog_utils/utils_llm.py
+++ b/defog_utils/utils_llm.py
@@ -1,6 +1,9 @@
-from dataclasses import dataclass
+import os
 import time
-from typing import Optional, List, Dict
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import google.generativeai as genai
 from anthropic import Anthropic
 from openai import OpenAI
 from together import Together
@@ -8,6 +11,7 @@
 client_anthropic = Anthropic()
 client_openai = OpenAI()
 client_together = Together()
+genai.configure(api_key=os.environ["GEMINI_API_KEY"])


 @dataclass
@@ -36,7 +40,7 @@ def chat_anthropic(
         sys_msg = messages[0]["content"]
         messages = messages[1:]
     else:
-        sys_msg = None
+        sys_msg = ""
     response = client_anthropic.messages.create(
         system=sys_msg,
         messages=messages,
@@ -97,7 +101,7 @@ def chat_openai(

 def chat_together(
     messages: List[Dict[str, str]],
-    model: str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+    model: str = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
     max_tokens: int = 4096,
     temperature: float = 0.0,
     stop: List[str] = [],
@@ -130,3 +134,49 @@ def chat_together(
         response.usage.prompt_tokens,
         response.usage.completion_tokens,
     )
+
+
+def chat_gemini(
+    messages: List[Dict[str, str]],
+    model: str = "gemini-1.5-pro",
+    max_tokens: int = 8192,
+    temperature: float = 0.0,
+    stop: List[str] = [],
+    json_mode: bool = False,
+    seed: int = 0,
+) -> Optional[LLMResponse]:
+    t = time.time()
+    generation_config = {
+        "temperature": temperature,
+        "max_output_tokens": max_tokens,
+        "response_mime_type": "application/json" if json_mode else "text/plain",
+        "stop_sequences": stop,
+        # "seed": seed, # seed is not supported in the current version
+    }
+    if messages[0]["role"] == "system":
+        system_msg = messages[0]["content"]
+        messages = messages[1:]
+    else:
+        system_msg = None
+    final_msg = messages[-1]["content"]
+    messages = messages[:-1]
+    for msg in messages:
+        if msg["role"] != "user":
+            msg["role"] = "model"
+    client_gemini = genai.GenerativeModel(model, generation_config=generation_config, system_instruction=system_msg)
+    chat = client_gemini.start_chat(
+        history=messages,
+    )
+    response = chat.send_message(final_msg)
+    if len(response.candidates) == 0:
+        print("Empty response")
+        return None
+    if response.candidates[0].finish_reason.value != 1: # 1 is the finish reason for STOP
+        print("Generation stopped before completion (e.g. max tokens reached)")
+        return None
+    return LLMResponse(
+        content=response.candidates[0].content.parts[0].text,
+        time=round(time.time() - t, 3),
+        input_tokens=response.usage_metadata.prompt_token_count,
+        output_tokens=response.usage_metadata.candidates_token_count,
+    )

diff --git a/defog_utils/utils_multi_llm.py b/defog_utils/utils_multi_llm.py
new file mode 100644
index 0000000..8770c2c
--- /dev/null
+++ b/defog_utils/utils_multi_llm.py
@@ -0,0 +1,28 @@
+import concurrent.futures
+from typing import Callable, Dict
+
+from .utils_llm import LLMResponse, chat_anthropic, chat_gemini, chat_openai, chat_together
+
+def map_model_to_chat_fn(model: str) -> Callable:
+    """
+    Returns the appropriate chat function based on the model.
+    """
+    if model.startswith("claude"):
+        return chat_anthropic
+    if model.startswith("gemini"):
+        return chat_gemini
+    if model.startswith("gpt"):
+        return chat_openai
+    if model.startswith("meta-llama") or model.startswith("mistralai") or model.startswith("Qwen"):
+        return chat_together
+    raise ValueError(f"Unknown model: {model}")
+
+def chat(models, messages, max_tokens=8192, temperature=0.0, stop=[], json_mode=False, seed=0) -> Dict[str, LLMResponse]:
+    """
+    Returns the response from the LLM API for each of the models passed in.
+    Output format is a dictionary keyed by model name.
+    """
+    with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
+        futures = {executor.submit(map_model_to_chat_fn(model), messages, model, max_tokens, temperature, stop, json_mode, seed): model for model in models}
+        responses = {futures[future]: future.result() for future in concurrent.futures.as_completed(futures)}
+        return responses
\ No newline at end of file

diff --git a/requirements.txt b/requirements.txt
index ee803eb..97d8cd5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 anthropic==0.37.1
+google-generativeai==0.8.3
 numpy
 openai==1.52.2
 psycopg2-binary==2.9.9

diff --git a/tests/test_utils_llm.py b/tests/test_utils_llm.py
index ac0c302..e9f2e09 100644
--- a/tests/test_utils_llm.py
+++ b/tests/test_utils_llm.py
@@ -3,11 +3,13 @@
 from ..defog_utils.utils_llm import (
     LLMResponse,
     chat_anthropic,
+    chat_gemini,
     chat_openai,
     chat_together,
 )

-messages = [
+messages_no_sys = [{"role": "user", "content": "Return a greeting in not more than 2 words\n"}]
+messages_sql = [
     {
         "role": "system",
         "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. Return only the SQL without ```.",
@@ -50,24 +52,72 @@
 ]

 acceptable_sql = [
-    "SELECT COUNT(*) FROM orders",
-    "SELECT COUNT(order_id) FROM orders",
+    "select count(*) from orders",
+    "select count(order_id) from orders",
+    "select count(*) as total_orders from orders",
+    "select count(order_id) as total_orders from orders",
 ]

-acceptable_sql_from_json = set(
-    [
-        "SELECT COUNT(order_id) as total_orders FROM orders;",
-        "SELECT COUNT(*) AS total_orders FROM orders;",
-        "SELECT COUNT(order_id) FROM orders;",
-        "SELECT COUNT(*) FROM orders;",
-    ]
-)

+class TestChatClients(unittest.TestCase):
+    def check_sql(self, sql: str):
+        self.assertIn(sql.strip(";\n").lower(), acceptable_sql)
+
-class TestChatClients(unittest.TestCase):
-    def test_chat_anthropic(self):
+    def test_chat_anthropic_no_sys(self):
+        response = chat_anthropic(
+            messages_no_sys,
+            model="claude-3-haiku-20240307",
+            max_tokens=10,
+            seed=0,
+        )
+        print(response)
+        self.assertIsInstance(response, LLMResponse)
+        self.assertIsInstance(response.content, str)
+        self.assertEqual(response.input_tokens, 18)
+        self.assertLessEqual(response.output_tokens, 10)
+
+    def test_chat_gemini_no_sys(self):
+        response = chat_gemini(
+            messages_no_sys,
+            model="gemini-1.5-flash",
+            max_tokens=10,
+            seed=0,
+        )
+        print(response)
+        self.assertIsInstance(response, LLMResponse)
+        self.assertIsInstance(response.content, str)
+        self.assertEqual(response.input_tokens, 12)
+        self.assertLessEqual(response.output_tokens, 10)
+
+    def test_chat_openai_no_sys(self):
+        response = chat_openai(
+            messages_no_sys,
+            model="gpt-4o-mini",
+            max_tokens=10,
+            seed=0,
+        )
+        print(response)
+        self.assertIsInstance(response, LLMResponse)
+        self.assertIsInstance(response.content, str)
+        self.assertEqual(response.input_tokens, 18)
+        self.assertLessEqual(response.output_tokens, 10)
+
+    def test_chat_together_no_sys(self):
+        response = chat_together(
+            messages_no_sys,
+            model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+            max_tokens=10,
+            seed=0,
+        )
+        print(response)
+        self.assertIsInstance(response, LLMResponse)
+        self.assertIsInstance(response.content, str)
+        self.assertEqual(response.input_tokens, 46) # Together appears to add a default system prompt when none is provided
+        self.assertLessEqual(response.output_tokens, 10)
+
+    def test_chat_anthropic_sql(self):
         response = chat_anthropic(
-            messages,
+            messages_sql,
             model="claude-3-haiku-20240307",
             max_tokens=100,
             stop=[";"],
@@ -75,31 +125,39 @@
         )
         print(response)
         self.assertIsInstance(response, LLMResponse)
-        self.assertIn(response.content, acceptable_sql)
+        self.check_sql(response.content)
         self.assertEqual(response.input_tokens, 90) # 90 input tokens
-        self.assertTrue(response.output_tokens < 10) # output tokens should be < 10
+        self.assertTrue(response.output_tokens < 15) # output tokens should be < 15

-    def test_chat_openai(self):
-        response = chat_openai(messages, model="gpt-4o-mini", stop=[";"], seed=0)
+    def test_chat_openai_sql(self):
+        response = chat_openai(messages_sql, model="gpt-4o-mini", stop=[";"], seed=0)
         print(response)
         self.assertIsInstance(response, LLMResponse)
-        self.assertIn(response.content, acceptable_sql)
+        self.check_sql(response.content)
         self.assertEqual(response.input_tokens, 83)
         self.assertTrue(response.output_tokens < 10) # output tokens should be < 10

-    def test_chat_together(self):
+    def test_chat_together_sql(self):
         response = chat_together(
-            messages,
+            messages_sql,
             model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
             stop=[";"],
             seed=0,
         )
         print(response)
         self.assertIsInstance(response, LLMResponse)
-        self.assertIn(response.content, acceptable_sql)
+        self.check_sql(response.content)
         self.assertEqual(response.input_tokens, 108)
         self.assertTrue(response.output_tokens < 10) # output tokens should be < 10

+    def test_chat_gemini_sql(self):
+        response = chat_gemini(messages_sql, model="gemini-1.5-flash", stop=[";"], seed=0)
+        print(response)
+        self.assertIsInstance(response, LLMResponse)
+        self.check_sql(response.content)
+        self.assertEqual(response.input_tokens, 86)
+        self.assertTrue(response.output_tokens < 10)
+
     def test_chat_json_anthropic(self):
         response = chat_anthropic(
             messages_json,
@@ -111,7 +169,7 @@ def test_chat_json_anthropic(self):
         print(response)
         self.assertIsInstance(response, LLMResponse)
         resp_dict = json.loads(response.content)
-        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
+        self.check_sql(resp_dict["sql"])
         self.assertIsInstance(resp_dict["reasoning"], str)
         self.assertIsInstance(response.input_tokens, int)
         self.assertIsInstance(response.output_tokens, int)
@@ -123,7 +181,7 @@ def test_chat_json_openai(self):
         print(response)
         self.assertIsInstance(response, LLMResponse)
         resp_dict = json.loads(response.content)
-        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
+        self.check_sql(resp_dict["sql"])
         self.assertIsInstance(resp_dict["reasoning"], str)
         self.assertIsInstance(response.input_tokens, int)
         self.assertIsInstance(response.output_tokens, int)
@@ -139,7 +197,19 @@ def test_chat_json_together(self):
         self.assertIsInstance(response, LLMResponse)
         raw_output = response.content
         resp_dict = json.loads(raw_output)
-        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
+        self.check_sql(resp_dict["sql"])
         self.assertIsInstance(resp_dict["reasoning"], str)
         self.assertIsInstance(response.input_tokens, int)
         self.assertIsInstance(response.output_tokens, int)
+
+    def test_chat_json_gemini(self):
+        response = chat_gemini(
+            messages_json, model="gemini-1.5-flash", seed=0, json_mode=True
+        )
+        print(response)
+        self.assertIsInstance(response, LLMResponse)
+        resp_dict = json.loads(response.content)
+        self.check_sql(resp_dict["sql"])
+        self.assertIsInstance(resp_dict["reasoning"], str)
+        self.assertIsInstance(response.input_tokens, int)
+        self.assertIsInstance(response.output_tokens, int)

diff --git a/tests/test_utils_multi_llm.py b/tests/test_utils_multi_llm.py
new file mode 100644
index 0000000..d5c1440
--- /dev/null
+++ b/tests/test_utils_multi_llm.py
@@ -0,0 +1,77 @@
+import unittest
+from ..defog_utils.utils_multi_llm import map_model_to_chat_fn, chat
+from ..defog_utils.utils_llm import LLMResponse, chat_anthropic, chat_gemini, chat_openai, chat_together
+
+messages_sql = [
+    {
+        "role": "system",
+        "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. Return only the SQL without ```.",
+    },
+    {
+        "role": "user",
+        "content": f"""Question: What is the total number of orders?
+
+Schema:
+```sql
+CREATE TABLE orders (
+    order_id int,
+    customer_id int,
+    employee_id int,
+    order_date date
+);
+```
+""",
+    },
+]
+
+acceptable_sql = [
+    "select count(*) from orders",
+    "select count(order_id) from orders",
+    "select count(*) as total_orders from orders",
+    "select count(order_id) as total_orders from orders",
+]
+
+class TestChatClients(unittest.TestCase):
+    def check_sql(self, sql: str):
+        self.assertIn(sql.strip(";\n").lower(), acceptable_sql)
+
+    def test_map_model_to_chat_fn(self):
+        self.assertEqual(map_model_to_chat_fn("claude-3-5-sonnet-20241022"), chat_anthropic)
+        self.assertEqual(map_model_to_chat_fn("gemini-1.5-pro"), chat_gemini)
+        self.assertEqual(map_model_to_chat_fn("gpt-4o"), chat_openai)
+        self.assertEqual(map_model_to_chat_fn("mistralai/Mistral-7B-Instruct-v0.3"), chat_together)
+        self.assertEqual(map_model_to_chat_fn("meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"), chat_together)
+        self.assertEqual(map_model_to_chat_fn("Qwen/Qwen2.5-72B-Instruct-Turbo"), chat_together)
+        with self.assertRaises(ValueError):
+            map_model_to_chat_fn("unknown-model")
+
+    def test_simple_chat(self):
+        models = ["claude-3-haiku-20240307", "gemini-1.5-flash-002", "gpt-4o-mini", "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"]
+        messages = [{"role": "user", "content": "Return a greeting in not more than 2 words\n"}]
+        responses = chat(models, messages, max_tokens=20, temperature=0.0, stop=[";"], json_mode=False, seed=0)
+        self.assertIsInstance(responses, dict)
+        for model in models:
+            self.assertIn(model, responses)
+            response = responses[model]
+            print(model, response)
+            self.assertIsInstance(response, LLMResponse)
+            self.assertIsInstance(response.content, str)
+            self.assertIsInstance(response.time, float)
+            self.assertLess(response.input_tokens, 50) # higher bound, as Together adds a default system prompt when none is provided
+            self.assertLess(response.output_tokens, 20)
+
+    def test_sql_chat(self):
+        models = ["claude-3-haiku-20240307", "gemini-1.5-flash-002", "gpt-4o-mini", "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"]
+        responses = chat(models, messages_sql, max_tokens=20, temperature=0.0, stop=[";"], json_mode=False, seed=0)
+        self.assertIsInstance(responses, dict)
+        for model in models:
+            self.assertIn(model, responses)
+            response = responses[model]
+            print(model, response)
+            self.assertIsInstance(response, LLMResponse)
+            self.assertIsInstance(response.content, str)
+            self.assertIsInstance(response.time, float)
+            self.assertLess(response.input_tokens, 110)
+            self.assertLess(response.output_tokens, 20)
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
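
Usage sketch (illustrative, not part of the diff): a minimal example of how the new multi-model chat() entry point is meant to be called, based on the signature in utils_multi_llm.py above. Model names are taken from the tests; all four keys (ANTHROPIC_API_KEY, OPENAI_API_KEY, TOGETHER_API_KEY, GEMINI_API_KEY) must be set, since defog_utils.utils_llm instantiates every provider client at import time.

    from defog_utils.utils_multi_llm import chat

    # One prompt, fanned out to several providers in parallel threads.
    messages = [{"role": "user", "content": "Return a greeting in not more than 2 words"}]
    responses = chat(
        ["claude-3-haiku-20240307", "gemini-1.5-flash", "gpt-4o-mini"],
        messages,
        max_tokens=20,
    )

    # The result is a dict keyed by model name; each value is an LLMResponse
    # with content, time, input_tokens and output_tokens fields.
    for model, response in responses.items():
        print(model, response.content, response.time)

Note that chat() blocks until every model has answered, and chat_gemini returns None on an empty or truncated candidate, so callers should be prepared for None values in the dict.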