Local Inference script for openfunctions v2 (#242)
* Local inference for open-functions-v2.
* Fix: Error in how javascript output was parsed.

---------

Co-authored-by: Charlie Cheng-Jie Ji <[email protected]>
ShishirPatil and CharlieJCJ authored Mar 9, 2024
1 parent 0a1bc0a commit 381c80b
Showing 4 changed files with 150 additions and 9 deletions.
27 changes: 24 additions & 3 deletions openfunctions/README.md
@@ -119,16 +119,16 @@ This is possible in OpenFunctions v2, because we ensure that the output includes

### End to End Example

-In the current directory, run the example code in `ofv2_hosted.py` to see how the model works.
+In the current directory, run the example code in `inference_hosted.py` to see how the model works.

```bash
-python ofv2_hosted.py
+python inference_hosted.py
```

Expected Output:

```bash
-(.py3) shishir@dhcp-132-64:~/Work/Gorilla/openfunctions/$ python ofv2_hosted.py
+(.py3) shishir@dhcp-132-64:~/Work/Gorilla/openfunctions/$ python inference_hosted.py
--------------------
Function call strings(s): get_current_weather(location='Boston, MA'), get_current_weather(location='San Francisco, CA')
--------------------
@@ -182,6 +182,8 @@ git clone https://github.com/tree-sitter/tree-sitter-java.git
git clone https://github.com/tree-sitter/tree-sitter-javascript.git
```

Please `git clone` these grammars and run the example code from within the `openfunctions` directory!

And you can use the following code to format the response:

```python
@@ -220,6 +222,25 @@ def format_response(response: str):

```

### End to End Example

In the current directory, run the example code in `inference_local.py` to see how the model works.

```bash
python inference_local.py
```
Expected Output (if you are using `query_1` and `functions_1`):
```
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00, 3.15s/it]
...
...
--------------------
Function call strings 1(s): get_current_weather(location='Boston, MA'), get_current_weather(location='San Francisco, CA')
--------------------
OpenAI compatible `function_call`: [{'name': 'get_current_weather', 'arguments': {'location': 'Boston, MA'}}, {'name': 'get_current_weather', 'arguments': {'location': 'San Francisco, CA'}}]
--------------------
```

**Note:** Use `get_prompt` and `format_response` only if you are hosting the model locally. If you are using the Berkeley-hosted models through the Chat-completion API, we do this in the backend, so you don't have to. The model is supported in Hugging Face 🤗 Transformers and can be run locally:


File renamed without changes.
111 changes: 111 additions & 0 deletions openfunctions/inference_local.py
@@ -0,0 +1,111 @@
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from openfunctions_utils import strip_function_calls, parse_function_call

def get_prompt(user_query: str, functions: list = []) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.
    Parameters:
    - user_query (str): The user's query.
    - functions (list): A list of functions to include in the prompt.
    Returns:
    - str: The formatted conversation prompt.
    """
    system = "You are an AI programming assistant, utilizing the Gorilla LLM model, developed by Gorilla LLM, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer."
    if len(functions) == 0:
        return f"{system}\n### Instruction: <<question>> {user_query}\n### Response: "
    functions_string = json.dumps(functions)
    return f"{system}\n### Instruction: <<function>>{functions_string}\n<<question>>{user_query}\n### Response: "


def format_response(response: str):
    """
    Formats the response from the OpenFunctions model.
    Parameters:
    - response (str): The response generated by the LLM.
    Returns:
    - str: The formatted response.
    - dict: The function call(s) extracted from the response.
    """
    function_call_dicts = None
    try:
        response = strip_function_calls(response)
        # Parallel function calls returned as a str, list[dict]
        if len(response) > 1:
            function_call_dicts = []
            for function_call in response:
                function_call_dicts.append(parse_function_call(function_call))
            response = ", ".join(response)
        # Single function call returned as a str, dict
        else:
            function_call_dicts = parse_function_call(response[0])
            response = response[0]
    except Exception as e:
        # Just faithfully return the generated response str to the user
        pass
    return response, function_call_dicts

# Device setup
device : str = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and tokenizer setup
model_id : str = "gorilla-llm/gorilla-openfunctions-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)

# Move model to device
model.to(device)

# Pipeline setup
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)

# Example usage 1
# This should return 2 functions with the right argument
query_1: str = "What's the weather like in the two cities of Boston and San Francisco?"
functions_1 = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
]

# Example usage 2
# This should return an error since the function can't help with the prompt
query_2: str = "What is the freezing point of water at a pressure of 10 kPa?"
functions_2 = [{"name": "thermodynamics.calculate_boiling_point", "description": "Calculate the boiling point of a given substance at a specific pressure.", "parameters": {"type": "object", "properties": {"substance": {"type": "string", "description": "The substance for which to calculate the boiling point."}, "pressure": {"type": "number", "description": "The pressure at which to calculate the boiling point."}, "unit": {"type": "string", "description": "The unit of the pressure. Default is 'kPa'."}}, "required": ["substance", "pressure"]}}]

# Generate prompt and obtain model output
prompt_1 = get_prompt(query_1, functions=functions_1)
output_1 = pipe(prompt_1)
fn_call_string, function_call_dict = format_response(output_1[0]['generated_text'])
print("--------------------")
print(f"Function call strings 1(s): {fn_call_string}")
print("--------------------")
print(f"OpenAI compatible `function_call`: {function_call_dict}")
print("--------------------")
21 changes: 15 additions & 6 deletions openfunctions/openfunctions_utils.py
@@ -1,14 +1,14 @@
-from openfunctions.utils.python_parser import parse_python_function_call
-from openfunctions.utils.java_parser import parse_java_function_call
-from openfunctions.utils.js_parser import parse_javascript_function_call
+from utils.python_parser import parse_python_function_call
+from utils.java_parser import parse_java_function_call
+from utils.js_parser import parse_javascript_function_call

FN_CALL_DELIMITER = "<<function>>"

def strip_function_calls(content: str) -> list[str]:
"""
Split the content by the function call delimiter and remove empty strings
"""
return [element.strip() for element in content.split(FN_CALL_DELIMITER) if element.strip()]
return [element.strip() for element in content.split(FN_CALL_DELIMITER)[2:] if element.strip()]

def parse_function_call(call: str) -> dict[str, any]:
"""
@@ -21,7 +21,16 @@ def parse_function_call(call: str) -> dict[str, any]:
    except Exception as e:
        # If Python parsing fails, try Java parsing
        try:
-            return parse_java_function_call(call)
+            java_result = parse_java_function_call(call)
+            if not java_result:
+                raise Exception("Java parsing failed")
+            return java_result
        except Exception as e:
            # If Java parsing also fails, try JavaScript parsing
-            return parse_javascript_function_call(call)
+            try:
+                javascript_result = parse_javascript_function_call(call)
+                if not javascript_result:
+                    raise Exception("JavaScript parsing failed")
+                return javascript_result
+            except:
+                return None

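For orientation, here is a sketch of how the updated helpers fit together. It assumes, as the `inference_local.py` script above suggests, that the text passed to `strip_function_calls` is the pipeline's full `generated_text`: the echoed prompt (which itself contains one `<<function>>` marker) followed by the model's `<<function>>`-prefixed calls, so the `[2:]` slice drops the prompt portion and keeps only the calls. The sample string and the resulting dict are illustrative assumptions, not output recorded in this commit.

```python
from openfunctions_utils import strip_function_calls, parse_function_call

# Hypothetical generated_text: prompt echo (one <<function>> marker) plus the model's response.
generated_text = (
    "You are an AI programming assistant...\n"
    "### Instruction: <<function>>[{\"name\": \"get_current_weather\", ...}]\n"
    "<<question>>What's the weather like in Boston?\n"
    "### Response: <<function>>get_current_weather(location='Boston, MA')"
)

# Splitting on <<function>> yields [prefix, functions + question, call]; [2:] keeps only the call.
calls = strip_function_calls(generated_text)
print(calls)  # ["get_current_weather(location='Boston, MA')"]

# Python parsing succeeds for this call; Java and JavaScript parsing are only tried as fallbacks.
print(parse_function_call(calls[0]))
# e.g. {'name': 'get_current_weather', 'arguments': {'location': 'Boston, MA'}}
```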