From df09ef813093474f3e7656622a9e09864f731452 Mon Sep 17 00:00:00 2001 From: jp Date: Fri, 25 Oct 2024 13:04:18 +0800 Subject: [PATCH] Add LLM API querying utils - add wrappers for LLM API querying that standardize the response format returned, with a few pieces of general information that we use (content, latency, input/output tokens) - updated paths to relative paths for tests - update requirements.txt --- README.md | 1 + defog_utils/utils_llm.py | 130 ++++++++++++++++++++++++++++ requirements.txt | 9 +- tests/test_utils_db.py | 2 +- tests/test_utils_instructions.py | 4 +- tests/test_utils_llm.py | 143 +++++++++++++++++++++++++++++++ tests/test_utils_sql.py | 18 ++-- 7 files changed, 288 insertions(+), 19 deletions(-) create mode 100644 defog_utils/utils_llm.py create mode 100644 tests/test_utils_llm.py diff --git a/README.md b/README.md index 96cacea..cc075fc 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ Internal utilities library for Defog. These utilities comprise: * sql/schema/instruction feature extraction * database connectors and validators +* basic llm querying across various APIs (openai, anthropic, together) These utilities should be purely stateless and have minimal library dependencies. 
diff --git a/defog_utils/utils_llm.py b/defog_utils/utils_llm.py new file mode 100644 index 0000000..290903d --- /dev/null +++ b/defog_utils/utils_llm.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass +import time +from typing import Optional, List, Dict +from anthropic import Anthropic +from openai import OpenAI +from together import Together + +client_anthropic = Anthropic() +client_openai = OpenAI() +client_together = Together() + +@dataclass +class LLMResponse: + content: str + time: float + input_tokens: int + output_tokens: int + +def chat_anthropic( + messages: List[Dict[str, str]], + model: str = "claude-3-5-sonnet-20241022", + max_tokens: int = 8192, + temperature: float = 0.0, + stop: List[str] = [], + json_mode: bool = False, + seed: int = 0, +) -> Optional[LLMResponse]: + """ + Returns the response from the Anthropic API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used. + Note that anthropic doesn't have explicit json mode api constraints, nor does it have a seed parameter. 
+ """ + t = time.time() + if len(messages) >= 1 and messages[0].get("role") == "system": + sys_msg = messages[0]["content"] + messages = messages[1:] + else: + sys_msg = None + response = client_anthropic.messages.create( + system=sys_msg, + messages=messages, + model=model, + max_tokens=max_tokens, + temperature=temperature, + stop_sequences=stop, + ) + if response.stop_reason == "max_tokens": + print("Max tokens reached") + return None + if len(response.content) == 0: + print("Empty response") + return None + return LLMResponse( + response.content[0].text, + round(time.time() - t, 3), + response.usage.input_tokens, + response.usage.output_tokens, + ) + + +def chat_openai( + messages: List[Dict[str, str]], + model: str = "gpt-4o", + max_tokens: int = 16384, + temperature: float = 0.0, + stop: List[str] = [], + json_mode: bool = False, + seed: int = 0, +) -> Optional[LLMResponse]: + """ + Returns the response from the OpenAI API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used. 
+ """ + t = time.time() + response = client_openai.chat.completions.create( + messages=messages, + model=model, + max_tokens=max_tokens, + temperature=temperature, + stop=stop, + response_format={"type": "json_object"} if json_mode else None, + seed=seed, + ) + if response.choices[0].finish_reason == "length": + print("Max tokens reached") + return None + if len(response.choices) == 0: + print("Empty response") + return None + return LLMResponse( + response.choices[0].message.content, + round(time.time() - t, 3), + response.usage.prompt_tokens, + response.usage.completion_tokens, + ) + + + def chat_together( + messages: List[Dict[str, str]], + model: str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", + max_tokens: int = 4096, + temperature: float = 0.0, + stop: List[str] = [], + json_mode: bool = False, + seed: int = 0, + ) -> Optional[LLMResponse]: + """ + Returns the response from the Together API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used. + Together's max_tokens refers to the maximum completion tokens, not the maximum total tokens; this function does not adjust it, so callers should budget it against the model's context window (e.g. 8192 - input_tokens) themselves. + Together doesn't have explicit json mode api constraints. 
+ """ + t = time.time() + response = client_together.chat.completions.create( + messages=messages, + model=model, + max_tokens=max_tokens, + temperature=temperature, + stop=stop, + seed=seed, + ) + if response.choices[0].finish_reason == "length": + print("Max tokens reached") + return None + if len(response.choices) == 0: + print("Empty response") + return None + return LLMResponse( + response.choices[0].message.content, + round(time.time() - t, 3), + response.usage.prompt_tokens, + response.usage.completion_tokens, + ) diff --git a/requirements.txt b/requirements.txt index 9881e74..ee803eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,8 @@ +anthropic==0.37.1 numpy +openai==1.52.2 psycopg2-binary==2.9.9 -sqlglot==25.8.1 -sqlglotrs==0.2.8 -sqlparse==0.5.1 \ No newline at end of file +sqlglot==25.27.0 +sqlglotrs==0.2.12 +sqlparse==0.5.1 +together==1.3.3 \ No newline at end of file diff --git a/tests/test_utils_db.py b/tests/test_utils_db.py index 7de2c0b..04eae0a 100644 --- a/tests/test_utils_db.py +++ b/tests/test_utils_db.py @@ -1,5 +1,5 @@ import unittest -from defog_utils.defog_utils.utils_db import ( +from ..defog_utils.utils_db import ( convert_data_type_postgres, fix_md, generate_aliases, diff --git a/tests/test_utils_instructions.py b/tests/test_utils_instructions.py index dd9d36b..a3300b9 100644 --- a/tests/test_utils_instructions.py +++ b/tests/test_utils_instructions.py @@ -1,6 +1,6 @@ import unittest -from defog_utils.defog_utils.constants import idk_strings -from defog_utils.defog_utils.utils_instructions import ( +from ..defog_utils.constants import idk_strings +from ..defog_utils.utils_instructions import ( InstructionFeatures, get_instruction_features, ) diff --git a/tests/test_utils_llm.py b/tests/test_utils_llm.py new file mode 100644 index 0000000..ecbc62e --- /dev/null +++ b/tests/test_utils_llm.py @@ -0,0 +1,143 @@ +import json +import unittest +from ..defog_utils.utils_llm import ( + LLMResponse, + chat_anthropic, + 
chat_openai, + chat_together, +) + +messages = [ + { + "role": "system", + "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. Return only the SQL without ```.", + }, + { + "role": "user", + "content": f"""Question: What is the total number of orders? +Schema: +```sql +CREATE TABLE orders ( + order_id int, + customer_id int, + employee_id int, + order_date date +); +``` +""", + }, +] + +messages_json = [ + { + "role": "system", + "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Return your answer only as a JSON object with the reasoning in the 'reasoning' field and SQL in the 'sql' field, without ```. For example, {\"sql\": \"...\", \"reasoning\": \"...\"}", + }, + { + "role": "user", + "content": f"""Question: What is the total number of orders? +Schema: +```sql +CREATE TABLE orders ( + order_id int, + customer_id int, + employee_id int, + order_date date +); +``` +""", + }, +] + +acceptable_sql = [ + "SELECT COUNT(*) FROM orders", + "SELECT COUNT(order_id) FROM orders", +] + +acceptable_sql_from_json = set([ + "SELECT COUNT(order_id) as total_orders FROM orders;", + "SELECT COUNT(*) AS total_orders FROM orders;", + "SELECT COUNT(order_id) FROM orders;", + "SELECT COUNT(*) FROM orders;", +]) + + +class TestChatClients(unittest.TestCase): + def test_chat_anthropic(self): + response = chat_anthropic( + messages, + model="claude-3-haiku-20240307", + max_tokens=100, + stop=[";"], + seed=0, + ) + print(response) + self.assertIsInstance(response, LLMResponse) + self.assertIn(response.content, acceptable_sql) + self.assertEqual(response.input_tokens, 90) # 90 input tokens + self.assertTrue(response.output_tokens < 10) # output tokens should be < 10 + + def test_chat_openai(self): + response = chat_openai(messages, model="gpt-4o-mini", stop=[";"], seed=0) + print(response) + self.assertIsInstance(response, LLMResponse) + 
self.assertIn(response.content, acceptable_sql) + self.assertEqual(response.input_tokens, 83) + self.assertTrue(response.output_tokens < 10) # output tokens should be < 10 + + def test_chat_together(self): + response = chat_together( + messages, + model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + stop=[";"], + seed=0, + ) + print(response) + self.assertIsInstance(response, LLMResponse) + self.assertIn(response.content, acceptable_sql) + self.assertEqual(response.input_tokens, 108) + self.assertTrue(response.output_tokens < 10) # output tokens should be < 10 + + def test_chat_json_anthropic(self): + response = chat_anthropic( + messages_json, + model="claude-3-haiku-20240307", + max_tokens=100, + seed=0, + json_mode=True, + ) + print(response) + self.assertIsInstance(response, LLMResponse) + resp_dict = json.loads(response.content) + self.assertIn(resp_dict["sql"], acceptable_sql_from_json) + self.assertIsInstance(resp_dict["reasoning"], str) + self.assertIsInstance(response.input_tokens, int) + self.assertIsInstance(response.output_tokens, int) + + def test_chat_json_openai(self): + response = chat_openai( + messages_json, model="gpt-4o-mini", seed=0, json_mode=True + ) + print(response) + self.assertIsInstance(response, LLMResponse) + resp_dict = json.loads(response.content) + self.assertIn(resp_dict["sql"], acceptable_sql_from_json) + self.assertIsInstance(resp_dict["reasoning"], str) + self.assertIsInstance(response.input_tokens, int) + self.assertIsInstance(response.output_tokens, int) + + def test_chat_json_together(self): + response = chat_together( + messages_json, + model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + seed=0, + json_mode=True, + ) + print(response) + self.assertIsInstance(response, LLMResponse) + raw_output = response.content + resp_dict = json.loads(raw_output) + self.assertIn(resp_dict["sql"], acceptable_sql_from_json) + self.assertIsInstance(resp_dict["reasoning"], str) + self.assertIsInstance(response.input_tokens, int) + 
self.assertIsInstance(response.output_tokens, int) diff --git a/tests/test_utils_sql.py b/tests/test_utils_sql.py index f31d6bb..c8414b0 100644 --- a/tests/test_utils_sql.py +++ b/tests/test_utils_sql.py @@ -1,5 +1,5 @@ import unittest -from defog_utils.defog_utils.utils_sql import ( +from ..defog_utils.utils_sql import ( add_space_padding, fix_comma, get_schema_features, @@ -345,19 +345,13 @@ def test_date_time_type_conversion(self): ) self.assertTrue(features.date_time_type_conversion) sql = "SELECT CONVERT(DATE, date_str_column) FROM table" - features = get_sql_features( - sql, self.md_cols, self.md_tables, dialect="tsql" - ) + features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql") self.assertTrue(features.date_time_type_conversion) sql = "SELECT CONVERT(INT, col) FROM table" - features = get_sql_features( - sql, self.md_cols, self.md_tables, dialect="tsql" - ) + features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql") self.assertFalse(features.date_time_type_conversion) sql = "SELECT DATEFROMPARTS(year_column, month_column, day_column) FROM table" - features = get_sql_features( - sql, self.md_cols, self.md_tables, dialect="tsql" - ) + features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql") self.assertTrue(features.date_time_type_conversion) def test_date_time_format(self): @@ -374,9 +368,7 @@ def test_date_time_format(self): self.assertFalse(features.date_time_format) self.assertFalse(features.strftime) sql = "SELECT FORMAT(column, 'YYYY-MM-DD') FROM table" - features = get_sql_features( - sql, self.md_cols, self.md_tables, dialect="tsql" - ) + features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql") self.assertTrue(features.date_time_format) def test_generate_timeseries(self):