Add LLM API querying utils
- add wrappers for LLM API querying that standardize the response format, returning the general information we use (content, latency, input/output tokens)
- update test imports to use relative paths
- update requirements.txt
wongjingping committed Oct 25, 2024
1 parent d0f0a2b commit df09ef8
Showing 7 changed files with 288 additions and 19 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -3,6 +3,7 @@
 Internal utilities library for Defog. These utilities comprise:
 * sql/schema/instruction feature extraction
 * database connectors and validators
+* basic LLM querying across various APIs (openai, anthropic, together)
 
 These utilities should be purely stateless and have minimal library dependencies.

130 changes: 130 additions & 0 deletions defog_utils/utils_llm.py
@@ -0,0 +1,130 @@
from dataclasses import dataclass
import time
from typing import Optional, List, Dict
from anthropic import Anthropic
from openai import OpenAI
from together import Together

client_anthropic = Anthropic()
client_openai = OpenAI()
client_together = Together()

@dataclass
class LLMResponse:
    content: str
    time: float
    input_tokens: int
    output_tokens: int

def chat_anthropic(
    messages: List[Dict[str, str]],
    model: str = "claude-3-5-sonnet-20241022",
    max_tokens: int = 8192,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the Anthropic API, the time taken to generate the
    response, the number of input tokens used, and the number of output tokens used.
    Note that Anthropic has no explicit JSON-mode constraint and no seed parameter,
    so json_mode and seed are accepted only to keep the interface consistent.
    """
    t = time.time()
    # Anthropic takes the system prompt as a separate argument rather than a message
    if len(messages) >= 1 and messages[0].get("role") == "system":
        sys_msg = messages[0]["content"]
        messages = messages[1:]
    else:
        sys_msg = None
    response = client_anthropic.messages.create(
        system=sys_msg,
        messages=messages,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        stop_sequences=stop,
    )
    if response.stop_reason == "max_tokens":
        print("Max tokens reached")
        return None
    if len(response.content) == 0:
        print("Empty response")
        return None
    return LLMResponse(
        response.content[0].text,
        round(time.time() - t, 3),
        response.usage.input_tokens,
        response.usage.output_tokens,
    )


def chat_openai(
    messages: List[Dict[str, str]],
    model: str = "gpt-4o",
    max_tokens: int = 16384,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the OpenAI API, the time taken to generate the
    response, the number of input tokens used, and the number of output tokens used.
    """
    t = time.time()
    response = client_openai.chat.completions.create(
        messages=messages,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=stop,
        response_format={"type": "json_object"} if json_mode else None,
        seed=seed,
    )
    # check for an empty choices list before indexing into it
    if len(response.choices) == 0:
        print("Empty response")
        return None
    if response.choices[0].finish_reason == "length":
        print("Max tokens reached")
        return None
    return LLMResponse(
        response.choices[0].message.content,
        round(time.time() - t, 3),
        response.usage.prompt_tokens,
        response.usage.completion_tokens,
    )


def chat_together(
    messages: List[Dict[str, str]],
    model: str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    max_tokens: int = 4096,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the Together API, the time taken to generate the
    response, the number of input tokens used, and the number of output tokens used.
    Note that Together's max_tokens is the maximum number of completion tokens, not
    the maximum total tokens, so using the full context window requires setting it to
    the context size minus the input tokens (e.g. 8192 - input_tokens).
    Together also has no explicit JSON-mode constraint.
    """
    t = time.time()
    response = client_together.chat.completions.create(
        messages=messages,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=stop,
        seed=seed,
    )
    # check for an empty choices list before indexing into it
    if len(response.choices) == 0:
        print("Empty response")
        return None
    if response.choices[0].finish_reason == "length":
        print("Max tokens reached")
        return None
    return LLMResponse(
        response.choices[0].message.content,
        round(time.time() - t, 3),
        response.usage.prompt_tokens,
        response.usage.completion_tokens,
    )
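
A minimal usage sketch (not part of the commit) of how these wrappers put all three providers behind the same LLMResponse fields; it assumes OPENAI_API_KEY is set in the environment and that the module is importable under the hypothetical path shown:

```python
# Minimal usage sketch, not part of this commit. Assumes OPENAI_API_KEY is set in the
# environment and that the module is importable as defog_utils.utils_llm (hypothetical path).
from defog_utils.utils_llm import chat_openai

example_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Return the number 42 as a JSON object with a 'value' field."},
]

resp = chat_openai(example_messages, model="gpt-4o-mini", json_mode=True, seed=0)
if resp is not None:  # the wrappers return None on truncated or empty responses
    print(resp.content)        # model output (a JSON string here)
    print(resp.time)           # latency in seconds
    print(resp.input_tokens)   # prompt tokens used
    print(resp.output_tokens)  # completion tokens used
```

chat_anthropic and chat_together take the same arguments, so swapping providers is just a matter of changing the function and model name.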
9 changes: 6 additions & 3 deletions requirements.txt
@@ -1,5 +1,8 @@
+anthropic==0.37.1
 numpy
+openai==1.52.2
 psycopg2-binary==2.9.9
-sqlglot==25.8.1
-sqlglotrs==0.2.8
-sqlparse==0.5.1
+sqlglot==25.27.0
+sqlglotrs==0.2.12
+sqlparse==0.5.1
+together==1.3.3
2 changes: 1 addition & 1 deletion tests/test_utils_db.py
@@ -1,5 +1,5 @@
 import unittest
-from defog_utils.defog_utils.utils_db import (
+from ..defog_utils.utils_db import (
     convert_data_type_postgres,
     fix_md,
     generate_aliases,
4 changes: 2 additions & 2 deletions tests/test_utils_instructions.py
@@ -1,6 +1,6 @@
 import unittest
-from defog_utils.defog_utils.constants import idk_strings
-from defog_utils.defog_utils.utils_instructions import (
+from ..defog_utils.constants import idk_strings
+from ..defog_utils.utils_instructions import (
     InstructionFeatures,
     get_instruction_features,
 )
143 changes: 143 additions & 0 deletions tests/test_utils_llm.py
@@ -0,0 +1,143 @@
import json
import unittest
from ..defog_utils.utils_llm import (
    LLMResponse,
    chat_anthropic,
    chat_openai,
    chat_together,
)

messages = [
    {
        "role": "system",
        "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. Return only the SQL without ```.",
    },
    {
        "role": "user",
        "content": f"""Question: What is the total number of orders?
Schema:
```sql
CREATE TABLE orders (
order_id int,
customer_id int,
employee_id int,
order_date date
);
```
""",
    },
]

messages_json = [
    {
        "role": "system",
        "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Return your answer only as a JSON object with the reasoning in the 'reasoning' field and SQL in the 'sql' field, without ```. For example, {\"sql\": \"...\", \"reasoning\": \"...\"}",
    },
    {
        "role": "user",
        "content": f"""Question: What is the total number of orders?
Schema:
```sql
CREATE TABLE orders (
order_id int,
customer_id int,
employee_id int,
order_date date
);
```
""",
    },
]

acceptable_sql = [
    "SELECT COUNT(*) FROM orders",
    "SELECT COUNT(order_id) FROM orders",
]

acceptable_sql_from_json = set([
    "SELECT COUNT(order_id) as total_orders FROM orders;",
    "SELECT COUNT(*) AS total_orders FROM orders;",
    "SELECT COUNT(order_id) FROM orders;",
    "SELECT COUNT(*) FROM orders;",
])


class TestChatClients(unittest.TestCase):
    def test_chat_anthropic(self):
        response = chat_anthropic(
            messages,
            model="claude-3-haiku-20240307",
            max_tokens=100,
            stop=[";"],
            seed=0,
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        self.assertIn(response.content, acceptable_sql)
        self.assertEqual(response.input_tokens, 90)  # 90 input tokens
        self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10

    def test_chat_openai(self):
        response = chat_openai(messages, model="gpt-4o-mini", stop=[";"], seed=0)
        print(response)
        self.assertIsInstance(response, LLMResponse)
        self.assertIn(response.content, acceptable_sql)
        self.assertEqual(response.input_tokens, 83)
        self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10

    def test_chat_together(self):
        response = chat_together(
            messages,
            model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
            stop=[";"],
            seed=0,
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        self.assertIn(response.content, acceptable_sql)
        self.assertEqual(response.input_tokens, 108)
        self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10

    def test_chat_json_anthropic(self):
        response = chat_anthropic(
            messages_json,
            model="claude-3-haiku-20240307",
            max_tokens=100,
            seed=0,
            json_mode=True,
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        resp_dict = json.loads(response.content)
        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
        self.assertIsInstance(resp_dict["reasoning"], str)
        self.assertIsInstance(response.input_tokens, int)
        self.assertIsInstance(response.output_tokens, int)

    def test_chat_json_openai(self):
        response = chat_openai(
            messages_json, model="gpt-4o-mini", seed=0, json_mode=True
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        resp_dict = json.loads(response.content)
        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
        self.assertIsInstance(resp_dict["reasoning"], str)
        self.assertIsInstance(response.input_tokens, int)
        self.assertIsInstance(response.output_tokens, int)

    def test_chat_json_together(self):
        response = chat_together(
            messages_json,
            model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
            seed=0,
            json_mode=True,
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        raw_output = response.content
        resp_dict = json.loads(raw_output)
        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
        self.assertIsInstance(resp_dict["reasoning"], str)
        self.assertIsInstance(response.input_tokens, int)
        self.assertIsInstance(response.output_tokens, int)
18 changes: 5 additions & 13 deletions tests/test_utils_sql.py
@@ -1,5 +1,5 @@
 import unittest
-from defog_utils.defog_utils.utils_sql import (
+from ..defog_utils.utils_sql import (
     add_space_padding,
     fix_comma,
     get_schema_features,
@@ -345,19 +345,13 @@ def test_date_time_type_conversion(self):
         )
         self.assertTrue(features.date_time_type_conversion)
         sql = "SELECT CONVERT(DATE, date_str_column) FROM table"
-        features = get_sql_features(
-            sql, self.md_cols, self.md_tables, dialect="tsql"
-        )
+        features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql")
         self.assertTrue(features.date_time_type_conversion)
         sql = "SELECT CONVERT(INT, col) FROM table"
-        features = get_sql_features(
-            sql, self.md_cols, self.md_tables, dialect="tsql"
-        )
+        features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql")
         self.assertFalse(features.date_time_type_conversion)
         sql = "SELECT DATEFROMPARTS(year_column, month_column, day_column) FROM table"
-        features = get_sql_features(
-            sql, self.md_cols, self.md_tables, dialect="tsql"
-        )
+        features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql")
         self.assertTrue(features.date_time_type_conversion)
 
     def test_date_time_format(self):
@@ -374,9 +368,7 @@ def test_date_time_format(self):
         self.assertFalse(features.date_time_format)
         self.assertFalse(features.strftime)
         sql = "SELECT FORMAT(column, 'YYYY-MM-DD') FROM table"
-        features = get_sql_features(
-            sql, self.md_cols, self.md_tables, dialect="tsql"
-        )
+        features = get_sql_features(sql, self.md_cols, self.md_tables, dialect="tsql")
         self.assertTrue(features.date_time_format)
 
     def test_generate_timeseries(self):
