From a0b880f166c821ada525283fcfcc952d75be15c6 Mon Sep 17 00:00:00 2001
From: Yifan Mai
Date: Tue, 28 May 2024 10:37:10 -0700
Subject: [PATCH] Add FinQA Scenario (#2588)

---
 src/helm/benchmark/metrics/fin_qa_metrics.py  |  60 +++
 .../metrics/fin_qa_metrics_helper.py          | 401 ++++++++++++++++++
 .../benchmark/run_specs/finance_run_specs.py  |  33 ++
 .../benchmark/scenarios/fin_qa_scenario.py    | 117 +++++
 src/helm/benchmark/static/schema_finance.yaml | 143 +++++++
 5 files changed, 754 insertions(+)
 create mode 100644 src/helm/benchmark/metrics/fin_qa_metrics.py
 create mode 100644 src/helm/benchmark/metrics/fin_qa_metrics_helper.py
 create mode 100644 src/helm/benchmark/run_specs/finance_run_specs.py
 create mode 100644 src/helm/benchmark/scenarios/fin_qa_scenario.py
 create mode 100644 src/helm/benchmark/static/schema_finance.yaml

diff --git a/src/helm/benchmark/metrics/fin_qa_metrics.py b/src/helm/benchmark/metrics/fin_qa_metrics.py
new file mode 100644
index 00000000000..04ce963f8a5
--- /dev/null
+++ b/src/helm/benchmark/metrics/fin_qa_metrics.py
@@ -0,0 +1,60 @@
+import math
+import json
+from typing import List, Union
+
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+from helm.benchmark.metrics.fin_qa_metrics_helper import (  # type: ignore
+    equal_program,
+    eval_program,
+    program_tokenization,
+)
+
+
+def _get_program_accuracy(reference_program: List[str], generated_program: List[str]) -> float:
+    return 1.0 if equal_program(reference_program, generated_program) else 0.0
+
+
+def _get_execution_accuracy(reference_execution: str, generated_program: List[str], table: List[List[str]]) -> float:
+    invalid_flag: int
+    generated_result: Union[str, float]
+    invalid_flag, generated_result = eval_program(generated_program, table)
+    if invalid_flag:
+        return 0.0
+    if reference_execution == "yes" or reference_execution == "no":
+        return 1.0 if reference_execution == generated_result else 0.0
+    else:
+        if not isinstance(generated_result, float):
+            return 0.0
+        return 1.0 if math.isclose(float(reference_execution), generated_result) else 0.0
+
+
+class FinQAMetric(Metric):
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        assert len(request_state.instance.references) == 3
+        reference_text = request_state.instance.references[0].output.text
+        reference_program = program_tokenization(reference_text)
+        reference_execution = request_state.instance.references[1].output.text
+        table: List[List[str]] = json.loads(request_state.instance.references[2].output.text)
+
+        assert request_state.result
+        assert len(request_state.result.completions) == 1
+        generated_text = request_state.result.completions[0].text.strip()
+        generated_program = program_tokenization(generated_text)
+
+        return [
+            Stat(MetricName("program_accuracy")).add(_get_program_accuracy(reference_program, generated_program)),
+            Stat(MetricName("execution_accuracy")).add(
+                _get_execution_accuracy(reference_execution, generated_program, table)
+            ),
+        ]
diff --git a/src/helm/benchmark/metrics/fin_qa_metrics_helper.py b/src/helm/benchmark/metrics/fin_qa_metrics_helper.py
new file mode 100644
index 00000000000..c735254ffb3
--- /dev/null
+++ b/src/helm/benchmark/metrics/fin_qa_metrics_helper.py
@@ -0,0 +1,401 @@
+# type: ignore
+# flake8: noqa
+# fmt: off
+"""Evaluation metrics for FinQA.
+
+This evaluation code is reproduced from the following URL with the following license.
+
+URL: https://github.com/czyssrs/FinQA/blob/0f16e2867befa6840783e58be38c9efb9229d742/code/evaluate/evaluate.py
+
+License: MIT License
+
+Copyright (c) 2021 Zhiyu Chen
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE."""
+
+from sympy import simplify
+
+
+all_ops = ["add", "subtract", "multiply", "divide", "exp", "greater", "table_max", \
+"table_min", "table_sum", "table_average"]
+
+
+def str_to_num(text):
+
+    text = text.replace(",", "")
+    try:
+        num = float(text)
+    except ValueError:
+        if "%" in text:
+            text = text.replace("%", "")
+            try:
+                num = float(text)
+                num = num / 100.0
+            except ValueError:
+                num = "n/a"
+        elif "const" in text:
+            text = text.replace("const_", "")
+            if text == "m1":
+                text = "-1"
+            num = float(text)
+        else:
+            num = "n/a"
+    return num
+
+def process_row(row_in):
+
+    row_out = []
+    invalid_flag = 0
+
+    for num in row_in:
+        num = num.replace("$", "").strip()
+        num = num.split("(")[0].strip()
+
+        num = str_to_num(num)
+
+        if num == "n/a":
+            invalid_flag = 1
+            break
+
+        row_out.append(num)
+
+    if invalid_flag:
+        return "n/a"
+
+    return row_out
+
+
+def eval_program(program, table):
+    '''
+    calculate the numerical results of the program
+    '''
+
+    invalid_flag = 0
+    this_res = "n/a"
+
+    try:
+        program = program[:-1]  # remove EOF
+        # check structure
+        for ind, token in enumerate(program):
+            if ind % 4 == 0:
+                if token.strip("(") not in all_ops:
+                    return 1, "n/a"
+            if (ind + 1) % 4 == 0:
+                if token != ")":
+                    return 1, "n/a"
+
+
+        program = "|".join(program)
+        steps = program.split(")")[:-1]
+
+
+        res_dict = {}
+
+        # print(program)
+
+        for ind, step in enumerate(steps):
+            step = step.strip()
+
+            if len(step.split("(")) > 2:
+                invalid_flag = 1
+                break
+            op = step.split("(")[0].strip("|").strip()
+            args = step.split("(")[1].strip("|").strip()
+
+            # print(args)
+            # print(op)
+
+            arg1 = args.split("|")[0].strip()
+            arg2 = args.split("|")[1].strip()
+
+            if op == "add" or op == "subtract" or op == "multiply" or op == "divide" or op == "exp" or op == "greater":
+
+                if "#" in arg1:
+                    arg1 = res_dict[int(arg1.replace("#", ""))]
+                else:
+                    # print(arg1)
+                    arg1 = str_to_num(arg1)
+                    if arg1 == "n/a":
+                        invalid_flag = 1
+                        break
+
+                if "#" in arg2:
+                    arg2 = res_dict[int(arg2.replace("#", ""))]
+                else:
+                    arg2 = str_to_num(arg2)
+                    if arg2 == "n/a":
+                        invalid_flag = 1
+                        break
+
if op == "add": + this_res = arg1 + arg2 + elif op == "subtract": + this_res = arg1 - arg2 + elif op == "multiply": + this_res = arg1 * arg2 + elif op == "divide": + this_res = arg1 / arg2 + elif op == "exp": + this_res = arg1 ** arg2 + elif op == "greater": + this_res = "yes" if arg1 > arg2 else "no" + + + # print("ind: ", ind) + # print(this_res) + res_dict[ind] = this_res + + + elif "table" in op: + table_dict = {} + for row in table: + table_dict[row[0]] = row[1:] + + if "#" in arg1: + arg1 = res_dict[int(arg1.replace("#", ""))] + else: + if arg1 not in table_dict: + invalid_flag = 1 + break + + cal_row = table_dict[arg1] + num_row = process_row(cal_row) + + if num_row == "n/a": + invalid_flag = 1 + break + if op == "table_max": + this_res = max(num_row) + elif op == "table_min": + this_res = min(num_row) + elif op == "table_sum": + this_res = sum(num_row) + elif op == "table_average": + this_res = sum(num_row) / len(num_row) + + # this_res = round(this_res, 5) + + res_dict[ind] = this_res + + # print(this_res) + + if this_res != "yes" and this_res != "no" and this_res != "n/a": + # print(this_res) + this_res = round(this_res, 5) + + except: + invalid_flag = 1 + + + return invalid_flag, this_res + + +def equal_program(program1, program2): + ''' + symbolic program if equal + program1: gold + program2: pred + ''' + + sym_map = {} + + program1 = program1[:-1] # remove EOF + program1 = "|".join(program1) + steps = program1.split(")")[:-1] + + invalid_flag = 0 + sym_ind = 0 + step_dict_1 = {} + + # symbolic map + for ind, step in enumerate(steps): + + step = step.strip() + + assert len(step.split("(")) <= 2 + + op = step.split("(")[0].strip("|").strip() + args = step.split("(")[1].strip("|").strip() + + arg1 = args.split("|")[0].strip() + arg2 = args.split("|")[1].strip() + + step_dict_1[ind] = step + + if "table" in op: + if step not in sym_map: + sym_map[step] = "a" + str(sym_ind) + sym_ind += 1 + + else: + if "#" not in arg1: + if arg1 not in sym_map: + sym_map[arg1] = "a" + str(sym_ind) + sym_ind += 1 + + if "#" not in arg2: + if arg2 not in sym_map: + sym_map[arg2] = "a" + str(sym_ind) + sym_ind += 1 + + + # check program 2 + step_dict_2 = {} + try: + program2 = program2[:-1] # remove EOF + # check structure + for ind, token in enumerate(program2): + if ind % 4 == 0: + if token.strip("(") not in all_ops: + print("structure error") + return False + if (ind + 1) % 4 == 0: + if token != ")": + print("structure error") + return False + + program2 = "|".join(program2) + steps = program2.split(")")[:-1] + + for ind, step in enumerate(steps): + step = step.strip() + + if len(step.split("(")) > 2: + return False + op = step.split("(")[0].strip("|").strip() + args = step.split("(")[1].strip("|").strip() + + # print(args) + # print(op) + + arg1 = args.split("|")[0].strip() + arg2 = args.split("|")[1].strip() + + step_dict_2[ind] = step + + if "table" in op: + if step not in sym_map: + return False + + else: + if "#" not in arg1: + if arg1 not in sym_map: + return False + else: + if int(arg1.strip("#")) >= ind: + return False + + if "#" not in arg2: + if arg2 not in sym_map: + return False + else: + if int(arg2.strip("#")) >= ind: + return False + except: + return False + + def symbol_recur(step, step_dict): + + step = step.strip() + op = step.split("(")[0].strip("|").strip() + args = step.split("(")[1].strip("|").strip() + + arg1 = args.split("|")[0].strip() + arg2 = args.split("|")[1].strip() + + # print(op) + # print(arg1) + # print(arg2) + + if "table" in op: + # as var + return sym_map[step] 
+ + if "#" in arg1: + arg1_ind = int(arg1.replace("#", "")) + arg1_part = symbol_recur(step_dict[arg1_ind], step_dict) + else: + arg1_part = sym_map[arg1] + + + if "#" in arg2: + arg2_ind = int(arg2.replace("#", "")) + arg2_part = symbol_recur(step_dict[arg2_ind], step_dict) + else: + arg2_part = sym_map[arg2] + + if op == "add": + return "( " + arg1_part + " + " + arg2_part + " )" + elif op == "subtract": + return "( " + arg1_part + " - " + arg2_part + " )" + elif op == "multiply": + return "( " + arg1_part + " * " + arg2_part + " )" + elif op == "divide": + return "( " + arg1_part + " / " + arg2_part + " )" + elif op == "exp": + return "( " + arg1_part + " ** " + arg2_part + " )" + elif op == "greater": + return "( " + arg1_part + " > " + arg2_part + " )" + + + # # derive symbolic program 1 + # print(program1) + steps = program1.split(")")[:-1] + # print(steps) + # print(steps) + # print(sym_map) + sym_prog1 = symbol_recur(steps[-1], step_dict_1) + sym_prog1 = simplify(sym_prog1, evaluate=False) + # print("########") + # print(sym_prog1) + + try: + # derive symbolic program 2 + steps = program2.split(")")[:-1] + sym_prog2 = symbol_recur(steps[-1], step_dict_2) + sym_prog2 = simplify(sym_prog2, evaluate=False) + # print(sym_prog2) + except: + return False + + return sym_prog1 == sym_prog2 + + +def program_tokenization(original_program): + original_program = original_program.split(', ') + program = [] + for tok in original_program: + cur_tok = '' + for c in tok: + if c == ')': + if cur_tok != '': + program.append(cur_tok) + cur_tok = '' + cur_tok += c + if c in ['(', ')']: + program.append(cur_tok) + cur_tok = '' + if cur_tok != '': + program.append(cur_tok) + program.append('EOF') + return program +# fmt: on diff --git a/src/helm/benchmark/run_specs/finance_run_specs.py b/src/helm/benchmark/run_specs/finance_run_specs.py new file mode 100644 index 00000000000..f0b9b57a337 --- /dev/null +++ b/src/helm/benchmark/run_specs/finance_run_specs.py @@ -0,0 +1,33 @@ +"""Run spec functions for the HELM Finance leaderboard. 
+
+Website: https://crfm.stanford.edu/helm/finance/"""
+
+from helm.benchmark.adaptation.common_adapter_specs import (
+    get_generation_adapter_spec,
+)
+from helm.benchmark.metrics.common_metric_specs import (
+    get_basic_metric_specs,
+)
+from helm.benchmark.metrics.metric import MetricSpec
+from helm.benchmark.run_spec import RunSpec, run_spec_function
+from helm.benchmark.scenarios.scenario import ScenarioSpec
+
+
+@run_spec_function("fin_qa")
+def get_fin_qa_spec() -> RunSpec:
+    from helm.benchmark.scenarios.fin_qa_scenario import INSTRUCTIONS
+
+    scenario_spec = ScenarioSpec(class_name="helm.benchmark.scenarios.fin_qa_scenario.FinQAScenario", args={})
+    adapter_spec = get_generation_adapter_spec(
+        instructions=INSTRUCTIONS, input_noun=None, output_noun="Program", max_tokens=100
+    )
+    metric_specs = get_basic_metric_specs([]) + [
+        MetricSpec(class_name="helm.benchmark.metrics.fin_qa_metrics.FinQAMetric")
+    ]
+    return RunSpec(
+        name="fin_qa",
+        scenario_spec=scenario_spec,
+        adapter_spec=adapter_spec,
+        metric_specs=metric_specs,
+        groups=["fin_qa"],
+    )
diff --git a/src/helm/benchmark/scenarios/fin_qa_scenario.py b/src/helm/benchmark/scenarios/fin_qa_scenario.py
new file mode 100644
index 00000000000..8d187255a32
--- /dev/null
+++ b/src/helm/benchmark/scenarios/fin_qa_scenario.py
@@ -0,0 +1,117 @@
+import os
+import json
+from typing import List
+
+from helm.common.general import ensure_directory_exists, ensure_file_downloaded
+from helm.benchmark.scenarios.scenario import (
+    Scenario,
+    Instance,
+    Input,
+    Output,
+    Reference,
+    TRAIN_SPLIT,
+    TEST_SPLIT,
+    CORRECT_TAG,
+)
+
+
+DATASET_URL_PREFIX = "https://github.com/czyssrs/FinQA/raw/0f16e2867befa6840783e58be38c9efb9229d742/dataset/"
+INSTRUCTIONS = """Presented with a financial report consisting of textual contents and a structured table, given a question, generate the reasoning program in the domain-specific language (DSL) that will be executed to get the answer.
+
+The DSL consists of mathematical operations and table operations as executable programs. The program consists of a sequence of operations. Each operation takes a list of arguments.
+
+There are 6 mathematical operations: add, subtract, multiply, divide, greater, exp, and 4 table aggregation operations table-max, table-min, table-sum, table-average, that apply aggregation operations on table rows. The mathematical operations take arguments of either numbers from the given reports, or a numerical result from a previous step.
+
+The table operations take arguments of table row names. We use the special token #n to denote the result from the nth step.
+
+For example, the program "divide(9413, 20.01), divide(8249, 9.48), subtract(#0, #1)" consists of 3 steps: the first and the second division steps take arguments from the table and the text, respectively; the third step then subtracts the results from the two previous steps.
+
+Definitions of all operations:
+
+[["Name", "Arguments", "Output", "Description"],
+["add", "number1, number2", "number", "add two numbers: number1 + number2"],
+["subtract", "number1, number2", "number", "subtract two numbers: number1 − number2"],
+["multiply", "number1, number2", "number", "multiply two numbers: number1 * number2"],
+["divide", "number1, number2", "number", "divide two numbers: number1 / number2"],
+["exp", "number1, number2", "number", "exponential: number1 ^ number2"],
+["greater", "number1, number2", "bool", "comparison: number1 > number2"],
+["table-sum", "table header", "number", "the summation of one table row"],
+["table-average", "table header", "number", "the average of one table row"],
+["table-max", "table header", "number", "the maximum number of one table row"],
+["table-min", "table header", "number", "the minimum number of one table row"]]
+
+Answer with only the program, without any additional explanation.
+"""  # noqa: E501
+
+
+class FinQAScenario(Scenario):
+    """
+    FinQA is a question answering task over financial reports that requires robust numerical reasoning.
+
+    FinQA: A Dataset of Numerical Reasoning over Financial Data
+    Paper: https://arxiv.org/abs/2109.00122
+    Code: https://github.com/czyssrs/FinQA
+
+    Presented with a financial report consisting of textual contents and a structured table, given a question,
+    the task is to generate the reasoning program in the domain-specific language (DSL) that will be executed
+    to get the answer.
+
+    We add the sub-headers "Pre-table text", "Table", "Post-table text" to the input. Example:
+
+    ```
+    Pre-table text: printing papers net sales for 2006 decreased 3% ( 3 % ) from both 2005 and 2004 due principally...
+    [more lines]
+    Table: [["in millions", "2006", "2005", "2004"], ["sales", "$ 6930", "$ 7170", "$ 7135"], ["operating profit", "$ 677", "$ 473", "$ 508"]]
+    Post-table text: u.s .
+    uncoated papers net sales in 2006 were $ 3.5 billion , compared with $ 3.2 billion in 2005 and $ 3.3 billion in 2004 .
+    [more lines]
+    Question: brazilian paper sales represented what percentage of printing papers in 2005?
+    Program:
+    ```
+    """  # noqa: E501
+
+    name = "fin_qa"
+    description = "FinQA"
+    tags = ["question_answering", "financial"]
+
+    def get_instances(self, output_path: str) -> List[Instance]:
+        data_path: str = os.path.join(output_path, "data")
+        ensure_directory_exists(data_path)
+        # Note: only train and test splits are used; dev split is not used
+        instances: List[Instance] = []
+        for split in [TRAIN_SPLIT, TEST_SPLIT]:
+            file_name = f"{split}.json"
+            target_path = os.path.join(data_path, file_name)
+            ensure_file_downloaded(
+                source_url=DATASET_URL_PREFIX + file_name,
+                target_path=target_path,
+            )
+            with open(target_path, "r") as f:
+                rows = json.load(f)
+            for row in rows:
+                pre_text = "Pre-table text: " + "\n".join(row["pre_text"])
+                table = "Table: " + json.dumps(row["table"])
+                post_text = "Post-table text: " + "\n".join(row["post_text"])
+                question = "Question: " + row["qa"]["question"]
+                text = "\n".join([pre_text, table, post_text, question])
+                references = [
+                    Reference(
+                        Output(text=str(row["qa"]["program"])),
+                        tags=[CORRECT_TAG],
+                    ),
+                    Reference(
+                        Output(text=str(row["qa"]["exe_ans"])),
+                        tags=[],
+                    ),
+                    Reference(
+                        Output(text=json.dumps(row["table"])),
+                        tags=[],
+                    ),
+                ]
+                instance: Instance = Instance(
+                    input=Input(text=text),
+                    references=references,
+                    split=split,
+                )
+                instances.append(instance)
+        return instances
diff --git a/src/helm/benchmark/static/schema_finance.yaml b/src/helm/benchmark/static/schema_finance.yaml
new file mode 100644
index 00000000000..6a68780d726
--- /dev/null
+++ b/src/helm/benchmark/static/schema_finance.yaml
@@ -0,0 +1,143 @@
+---
+############################################################
+metrics:
+  # Infrastructure metrics:
+  - name: num_perplexity_tokens
+    display_name: '# tokens'
+    description: Average number of tokens in the predicted output (for language modeling, the input too).
+  - name: num_bytes
+    display_name: '# bytes'
+    description: Average number of bytes in the predicted output (for language modeling, the input too).
+
+  - name: num_references
+    display_name: '# ref'
+    description: Number of references.
+  - name: num_train_trials
+    display_name: '# trials'
+    description: Number of trials, where in each trial we choose an independent, random set of training instances.
+  - name: estimated_num_tokens_cost
+    display_name: 'cost'
+    description: An estimate of the number of tokens (including prompt and output completions) needed to perform the request.
+  - name: num_prompt_tokens
+    display_name: '# prompt tokens'
+    description: Number of tokens in the prompt.
+  - name: num_prompt_characters
+    display_name: '# prompt chars'
+    description: Number of characters in the prompt.
+  - name: num_completion_tokens
+    display_name: '# completion tokens'
+    description: Actual number of completion tokens (over all completions).
+  - name: num_output_tokens
+    display_name: '# output tokens'
+    description: Actual number of output tokens.
+  - name: max_num_output_tokens
+    display_name: 'Max output tokens'
+    description: Maximum number of output tokens (overestimate since we might stop earlier due to stop sequences).
+  - name: num_requests
+    display_name: '# requests'
+    description: Number of distinct API requests.
+  - name: num_instances
+    display_name: '# eval'
+    description: Number of evaluation instances.
+  - name: num_train_instances
+    display_name: '# train'
+    description: Number of training instances (e.g., in-context examples).
+  - name: prompt_truncated
+    display_name: truncated
+    description: Fraction of instances where the prompt itself was truncated (implies that there were no in-context examples).
+  - name: finish_reason_length
+    display_name: finish b/c length
+    description: Fraction of instances where the output was terminated because of the max tokens limit.
+  - name: finish_reason_stop
+    display_name: finish b/c stop
+    description: Fraction of instances where the output was terminated because of the stop sequences.
+  - name: finish_reason_endoftext
+    display_name: finish b/c endoftext
+    description: Fraction of instances where the output was terminated because the end of text token was generated.
+  - name: finish_reason_unknown
+    display_name: finish b/c unknown
+    description: Fraction of instances where the output was terminated for unknown reasons.
+  - name: num_completions
+    display_name: '# completions'
+    description: Number of completions.
+  - name: predicted_index
+    display_name: Predicted index
+    description: Integer index of the reference (0, 1, ...) that was predicted by the model (for multiple-choice).
+
+  # Accuracy metrics:
+  - name: program_accuracy
+    display_name: Program Accuracy
+    description: Accuracy of the generated programs
+    lower_is_better: false
+  - name: execution_accuracy
+    display_name: Execution Accuracy
+    description: Accuracy of the final result of the generated program
+    lower_is_better: false
+
+############################################################
+perturbations: []
+
+############################################################
+metric_groups:
+  - name: accuracy
+    display_name: Accuracy
+    metrics:
+      - name: ${main_name}
+        split: ${main_split}
+
+  - name: efficiency
+    display_name: Efficiency
+    metrics:
+      - name: inference_runtime
+        split: ${main_split}
+
+  - name: general_information
+    display_name: General information
+    hide_win_rates: true
+    metrics:
+      - name: num_instances
+        split: ${main_split}
+      - name: num_train_instances
+        split: ${main_split}
+      - name: prompt_truncated
+        split: ${main_split}
+      - name: num_prompt_tokens
+        split: ${main_split}
+      - name: num_output_tokens
+        split: ${main_split}
+
+############################################################
+run_groups:
+  - name: financial_scenarios
+    display_name: Financial Scenarios
+    description: Scenarios for the financial domain
+    category: All scenarios
+    subgroups:
+      - fin_qa
+
+  - name: fin_qa
+    display_name: FinQA
+    description: The FinQA benchmark for numeric reasoning over financial data, with question answering pairs written by financial experts over financial reports [(Chen et al., 2021)](https://arxiv.org/abs/2109.00122/).
+    metric_groups:
+      - accuracy
+      - efficiency
+      - general_information
+    environment:
+      main_name: program_accuracy
+      main_split: test
+    taxonomy:
+      task: question answering with numeric reasoning
+      what: financial reports
+      who: financial experts
+      when: 1999 to 2019
+      language: English
+
+  - name: financial_scenarios_ablations
+    display_name: Financial Scenarios Ablations
+    description: Scenarios for the financial domain with ablations
+    category: All scenarios
+    subgroups:
+      - fin_qa
+    adapter_keys_shown:
+      - model
+      - max_train_instances
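
As a quick sanity check of the pieces above, here is a minimal usage sketch (in
Python, assuming this patch is installed) that exercises the helpers from
fin_qa_metrics_helper.py on the worked example in INSTRUCTIONS. The empty table
argument is an assumption that holds here because the program uses no table_*
operation.

    from helm.benchmark.metrics.fin_qa_metrics_helper import (
        equal_program,
        eval_program,
        program_tokenization,
    )

    # Tokenize the DSL program from the INSTRUCTIONS example.
    program = program_tokenization("divide(9413, 20.01), divide(8249, 9.48), subtract(#0, #1)")
    # -> ['divide(', '9413', '20.01', ')', 'divide(', '8249', '9.48', ')',
    #     'subtract(', '#0', '#1', ')', 'EOF']

    # eval_program returns (invalid_flag, result) and rounds the final
    # numeric result to 5 decimal places.
    invalid_flag, result = eval_program(program, [])
    assert invalid_flag == 0
    assert result == round(9413 / 20.01 - 8249 / 9.48, 5)

    # equal_program compares gold and predicted programs symbolically via
    # sympy, so algebraically equivalent programs also count as correct.
    gold = program_tokenization("divide(9413, 20.01), divide(8249, 9.48), subtract(#0, #1)")
    assert equal_program(gold, program)

FinQAMetric then reports program_accuracy from equal_program against the gold
program (reference 0), and execution_accuracy by comparing eval_program's
result against the dataset's exe_ans (reference 1) using math.isclose.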