Commit
- add wrappers for LLM API querying that standardize the response format, returning a few general fields that we use (content, latency, input/output tokens)
- update paths to relative paths for tests
- update requirements.txt
1 parent d0f0a2b · commit c0743db
Showing 7 changed files with 292 additions and 19 deletions.
defog_utils/utils_llm.py
@@ -0,0 +1,132 @@
from dataclasses import dataclass
import time
from typing import Optional, List, Dict
from anthropic import Anthropic
from openai import OpenAI
from together import Together

client_anthropic = Anthropic()
client_openai = OpenAI()
client_together = Together()

@dataclass
class LLMResponse:
    content: str
    time: float
    input_tokens: int
    output_tokens: int

def chat_anthropic(
    messages: List[Dict[str, str]],
    model: str = "claude-3-5-sonnet-20241022",
    max_tokens: int = 8192,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the Anthropic API, the time taken to generate the
    response, and the number of input and output tokens used.
    Note that Anthropic has no explicit JSON-mode API constraint and no seed
    parameter; `json_mode` and `seed` are accepted only for interface parity
    with the other wrappers.
    """
    t = time.time()
    # Anthropic takes the system prompt as a separate argument, not as a message.
    if len(messages) >= 1 and messages[0].get("role") == "system":
        sys_msg = messages[0]["content"]
        messages = messages[1:]
    else:
        sys_msg = None
    # Only pass `system` when a system message is present.
    kwargs = {"system": sys_msg} if sys_msg is not None else {}
    response = client_anthropic.messages.create(
        messages=messages,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        stop_sequences=stop,
        **kwargs,
    )
    if response.stop_reason == "max_tokens":
        print("Max tokens reached")
        return None
    if len(response.content) == 0:
        print("Empty response")
        return None
    return LLMResponse(
        response.content[0].text,
        round(time.time() - t, 3),
        response.usage.input_tokens,
        response.usage.output_tokens,
    )

def chat_openai(
    messages: List[Dict[str, str]],
    model: str = "gpt-4o",
    max_tokens: int = 16384,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the OpenAI API, the time taken to generate the
    response, and the number of input and output tokens used.
    """
    t = time.time()
    response = client_openai.chat.completions.create(
        messages=messages,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=stop,
        response_format={"type": "json_object"} if json_mode else None,
        seed=seed,
    )
    # Check for an empty choices list before indexing into it.
    if len(response.choices) == 0:
        print("Empty response")
        return None
    if response.choices[0].finish_reason == "length":
        print("Max tokens reached")
        return None
    return LLMResponse(
        response.choices[0].message.content,
        round(time.time() - t, 3),
        response.usage.prompt_tokens,
        response.usage.completion_tokens,
    )

def chat_together(
    messages: List[Dict[str, str]],
    model: str = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    max_tokens: int = 4096,
    temperature: float = 0.0,
    stop: List[str] = [],
    json_mode: bool = False,
    seed: int = 0,
) -> Optional[LLMResponse]:
    """
    Returns the response from the Together API, the time taken to generate the
    response, and the number of input and output tokens used.
    Together's max_tokens is the maximum number of completion tokens, not the
    maximum total tokens, so callers should budget it as 8192 - input_tokens.
    Together has no explicit JSON-mode API constraint; `json_mode` is accepted
    only for interface parity.
    """
    t = time.time()
    response = client_together.chat.completions.create(
        messages=messages,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=stop,
        seed=seed,
    )
    # Check for an empty choices list before indexing into it.
    if len(response.choices) == 0:
        print("Empty response")
        return None
    if response.choices[0].finish_reason == "length":
        print("Max tokens reached")
        return None
    return LLMResponse(
        response.choices[0].message.content,
        round(time.time() - t, 3),
        response.usage.prompt_tokens,
        response.usage.completion_tokens,
    )
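For reference, a minimal usage sketch of these wrappers (the prompt and model here are illustrative; the SDK clients are assumed to read ANTHROPIC_API_KEY, OPENAI_API_KEY, and TOGETHER_API_KEY from the environment, which is the default behavior of all three SDKs):

from defog_utils.utils_llm import chat_openai

# Hypothetical one-off query; the three chat_* wrappers share the same interface.
resp = chat_openai(
    [{"role": "user", "content": "Say hello in one word."}],
    model="gpt-4o-mini",
)
if resp is not None:  # None signals a truncated or empty response
    print(resp.content, resp.time, resp.input_tokens, resp.output_tokens)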
requirements.txt
@@ -1,5 +1,8 @@
+anthropic==0.37.1
 numpy
+openai==1.52.2
 psycopg2-binary==2.9.9
-sqlglot==25.8.1
-sqlglotrs==0.2.8
-sqlparse==0.5.1
+sqlglot==25.27.0
+sqlglotrs==0.2.12
+sqlparse==0.5.1
+together==1.3.3
@@ -0,0 +1,145 @@
import json
import unittest
from ..defog_utils.utils_llm import (
    LLMResponse,
    chat_anthropic,
    chat_openai,
    chat_together,
)

messages = [
    {
        "role": "system",
        "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. Return only the SQL without ```.",
    },
    {
        "role": "user",
        "content": """Question: What is the total number of orders?
Schema:
```sql
CREATE TABLE orders (
  order_id int,
  customer_id int,
  employee_id int,
  order_date date
);
```
""",
    },
]

messages_json = [
    {
        "role": "system",
        "content": 'Your task is to generate SQL given a natural language question and schema of the user\'s database. Return your answer only as a JSON object with the reasoning in the \'reasoning\' field and SQL in the \'sql\' field, without ```. For example, {"sql": "...", "reasoning": "..."}',
    },
    {
        "role": "user",
        "content": """Question: What is the total number of orders?
Schema:
```sql
CREATE TABLE orders (
  order_id int,
  customer_id int,
  employee_id int,
  order_date date
);
```
""",
    },
]

acceptable_sql = [
    "SELECT COUNT(*) FROM orders",
    "SELECT COUNT(order_id) FROM orders",
]

acceptable_sql_from_json = set(
    [
        "SELECT COUNT(order_id) as total_orders FROM orders;",
        "SELECT COUNT(*) AS total_orders FROM orders;",
        "SELECT COUNT(order_id) FROM orders;",
        "SELECT COUNT(*) FROM orders;",
    ]
)

class TestChatClients(unittest.TestCase):
    def test_chat_anthropic(self):
        response = chat_anthropic(
            messages,
            model="claude-3-haiku-20240307",
            max_tokens=100,
            stop=[";"],
            seed=0,
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        self.assertIn(response.content, acceptable_sql)
        self.assertEqual(response.input_tokens, 90)  # 90 input tokens
        self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10

    def test_chat_openai(self):
        response = chat_openai(messages, model="gpt-4o-mini", stop=[";"], seed=0)
        print(response)
        self.assertIsInstance(response, LLMResponse)
        self.assertIn(response.content, acceptable_sql)
        self.assertEqual(response.input_tokens, 83)
        self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10

    def test_chat_together(self):
        response = chat_together(
            messages,
            model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
            stop=[";"],
            seed=0,
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        self.assertIn(response.content, acceptable_sql)
        self.assertEqual(response.input_tokens, 108)
        self.assertTrue(response.output_tokens < 10)  # output tokens should be < 10

    def test_chat_json_anthropic(self):
        response = chat_anthropic(
            messages_json,
            model="claude-3-haiku-20240307",
            max_tokens=100,
            seed=0,
            json_mode=True,
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        resp_dict = json.loads(response.content)
        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
        self.assertIsInstance(resp_dict["reasoning"], str)
        self.assertIsInstance(response.input_tokens, int)
        self.assertIsInstance(response.output_tokens, int)

    def test_chat_json_openai(self):
        response = chat_openai(
            messages_json, model="gpt-4o-mini", seed=0, json_mode=True
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        resp_dict = json.loads(response.content)
        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
        self.assertIsInstance(resp_dict["reasoning"], str)
        self.assertIsInstance(response.input_tokens, int)
        self.assertIsInstance(response.output_tokens, int)

    def test_chat_json_together(self):
        response = chat_together(
            messages_json,
            model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
            seed=0,
            json_mode=True,
        )
        print(response)
        self.assertIsInstance(response, LLMResponse)
        resp_dict = json.loads(response.content)
        self.assertIn(resp_dict["sql"], acceptable_sql_from_json)
        self.assertIsInstance(resp_dict["reasoning"], str)
        self.assertIsInstance(response.input_tokens, int)
        self.assertIsInstance(response.output_tokens, int)
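Because the test module uses a relative import (from ..defog_utils.utils_llm import ...), it must be run as part of a package rather than as a standalone script; assuming the test file sits in a package with an __init__.py alongside defog_utils, running python -m unittest discover from the repository root should pick it up. Note that the json_mode tests parse response.content with json.loads, so they also verify that each provider returned well-formed JSON even though Anthropic and Together have no native JSON mode.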