Local Inference script for openfunctions v2 (#242)
* Local inference for open-functions-v2.
* Fix: Error in how javascript output was parsed.

Co-authored-by: Charlie Cheng-Jie Ji <[email protected]>
1 parent 0a1bc0a · commit 381c80b
Showing 4 changed files with 150 additions and 9 deletions.
File renamed without changes.
@@ -0,0 +1,111 @@
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from openfunctions_utils import strip_function_calls, parse_function_call


def get_prompt(user_query: str, functions: list = []) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list): A list of functions to include in the prompt.

    Returns:
    - str: The formatted conversation prompt.
    """
    system = "You are an AI programming assistant, utilizing the Gorilla LLM model, developed by Gorilla LLM, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer."
    if len(functions) == 0:
        return f"{system}\n### Instruction: <<question>> {user_query}\n### Response: "
    functions_string = json.dumps(functions)
    return f"{system}\n### Instruction: <<function>>{functions_string}\n<<question>>{user_query}\n### Response: "


def format_response(response: str):
    """
    Formats the response from the OpenFunctions model.

    Parameters:
    - response (str): The response generated by the LLM.

    Returns:
    - str: The formatted response.
    - dict: The function call(s) extracted from the response.
    """
    function_call_dicts = None
    try:
        # strip_function_calls is expected to return the extracted call string(s) as a list
        response = strip_function_calls(response)
        # Parallel function calls returned as a str, list[dict]
        if len(response) > 1:
            function_call_dicts = []
            for function_call in response:
                function_call_dicts.append(parse_function_call(function_call))
            response = ", ".join(response)
        # Single function call returned as a str, dict
        else:
            function_call_dicts = parse_function_call(response[0])
            response = response[0]
    except Exception as e:
        # Just faithfully return the generated response str to the user
        pass
    return response, function_call_dicts

# Device setup
device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and tokenizer setup
model_id: str = "gorilla-llm/gorilla-openfunctions-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)

# Move model to device
model.to(device)

# Pipeline setup
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)

# Example usage 1
# This should return 2 function calls with the right arguments
query_1: str = "What's the weather like in the two cities of Boston and San Francisco?"
functions_1 = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
]

# Example usage 2
# This should return an error since the function can't help with the prompt
query_2: str = "What is the freezing point of water at a pressure of 10 kPa?"
functions_2 = [
    {
        "name": "thermodynamics.calculate_boiling_point",
        "description": "Calculate the boiling point of a given substance at a specific pressure.",
        "parameters": {
            "type": "object",
            "properties": {
                "substance": {"type": "string", "description": "The substance for which to calculate the boiling point."},
                "pressure": {"type": "number", "description": "The pressure at which to calculate the boiling point."},
                "unit": {"type": "string", "description": "The unit of the pressure. Default is 'kPa'."},
            },
            "required": ["substance", "pressure"],
        },
    }
]

# Generate prompt and obtain model output
prompt_1 = get_prompt(query_1, functions=functions_1)
output_1 = pipe(prompt_1)
fn_call_string, function_call_dict = format_response(output_1[0]['generated_text'])
print("--------------------")
print(f"Function call string(s) 1: {fn_call_string}")
print("--------------------")
print(f"OpenAI compatible `function_call`: {function_call_dict}")
print("--------------------")