From a5068af9aa4e6a668bbc4ea18034c2f1e446aa7f Mon Sep 17 00:00:00 2001 From: Fanjia-Yan <78303449+Fanjia-Yan@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:36:42 -0700 Subject: [PATCH] [Update Gemini-1.0-Pro result checker] (#245) This PR adds checker for Gemini-1.0-pro result. I have tested on AST and Executable test suite to make sure that the statistics and result of leaderboard matches. --------- Co-authored-by: Shishir Patil <30296397+ShishirPatil@users.noreply.github.com> --- .../openfunctions_ast_checker.py | 8 +++-- .../openfunctions_executable_checker.py | 34 ++++++++++++------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/berkeley-function-call-leaderboard/openfunctions_ast_checker.py b/berkeley-function-call-leaderboard/openfunctions_ast_checker.py index 58ea2c04e..f6b6f1bc1 100644 --- a/berkeley-function-call-leaderboard/openfunctions_ast_checker.py +++ b/berkeley-function-call-leaderboard/openfunctions_ast_checker.py @@ -295,7 +295,7 @@ def ast_parse(x): # Check if the input file is a gpt or gorilla file # Here we are receiving JSON schema where we do the AST checking. -if "gpt" in file_name or "glaive" in file_name or "fire" in file_name or "mistral-large-latest" in file_name: +if "gpt" in file_name or "glaive" in file_name or "fire" in file_name or "mistral-large-latest" in file_name or "gemini" in file_name: total = 0 success = 0 for k in range(len(example)): @@ -335,7 +335,11 @@ def ast_parse(x): keyword = "text" else: keyword = "result" - if len(example[k][keyword]) != len(answer[k].keys()): + try: + if len(example[k][keyword]) != len(answer[k].keys()): + total += 1 + continue + except: total += 1 continue for item in example[k][keyword]: diff --git a/berkeley-function-call-leaderboard/openfunctions_executable_checker.py b/berkeley-function-call-leaderboard/openfunctions_executable_checker.py index 076ff9877..8298d3601 100644 --- a/berkeley-function-call-leaderboard/openfunctions_executable_checker.py +++ b/berkeley-function-call-leaderboard/openfunctions_executable_checker.py @@ -38,20 +38,31 @@ def convert_to_function_call(data_str): output_func_list = [] for func in data_str: # Step 2: Split the string into function name and parameters parts - func_name_part, params_part = func.split(":", 1) + try: + func_name_part, params_part = func.split(":", 1) + except: + continue # Step 3: Clean and extract the function name func_name = func_name_part.strip("{").strip(" '") - # Step 4: Extract and clean the parameters string - if params_part[-1] == "}": - params_part = params_part[:-2] - params_str = params_part.strip(" '") - # Step 5: Replace single quotes with double quotes for JSON parsing - params_str = params_str.replace("'", '"') - params_str = params_str.replace("\\" + "n", "") - # Step 6: Load the parameters string as a dictionary - params = json.loads(params_str) + if "gemini" in model_name: + if params_part[0] != "{": + params_part = params_part[1:] + params_part = params_part[0:-1] + try: + params = eval(params_part) + except: + continue + else: + if params_part[-1] == "}": + params_part = params_part[:-2] + params_str = params_part.strip(" '") + # Step 5: Replace single quotes with double quotes for JSON parsing + params_str = params_str.replace("'", '"') + params_str = params_str.replace("\\" + "n", "") + # Step 6: Load the parameters string as a dictionary + params = json.loads(params_str) function_string = func_name + "(" for k,v in params.items(): if isinstance(v, str): @@ -117,7 +128,7 @@ def convert_to_function_call(data_str): execution_result_type = testing_data[i]["execution_result_type"] if type(execution_result_type) is str and len(execution_result) > 1: execution_result_type = [execution_result_type] * len(execution_result) - if ("gpt" in model_name or "fire" in model_name or "mistral-large-latest" in model_name) and input_file is None: + if ("gpt" in model_name or "fire" in model_name or "mistral-large-latest" in model_name or "gemini" in model_name) and input_file is None: try: result = convert_to_function_call(result_data[i]["result"]) except: @@ -125,7 +136,6 @@ def convert_to_function_call(data_str): continue elif input_file is not None and "gorilla" in input_file: result = result_data[i]["text"] - print(result) elif input_file is not None and "gemma" in input_file: pattern = re.compile(r"\b\w+(?:\.\w+)?\b\([^)]*\)")