|
| 1 | +from transformers import AutoTokenizer |
| 2 | +import transformers |
| 3 | +import torch |
| 4 | +import re |
| 5 | +import json |
| 6 | +import csv |
| 7 | +import templates |
| 8 | +import subprocess |
| 9 | + |
# Hugging Face model id for the chat-tuned Llama-2 7B checkpoint.
model = "meta-llama/Llama-2-7b-chat-hf"

# Dataset of evaluation samples loaded below (context/question/label records).
JSON_filename = 'PARARULE_plus_step2_People_sample.json'
# Scratch file the generated pyDatalog code is written to before execution.
PY_filename = 'pyDatalog_processing.py'
| 14 | + |
| 15 | + |
def remove_spaces(text):
    """Collapse runs of spaces to one and trim space padding from every line."""
    collapsed = re.sub(r' {2,}', ' ', text)
    # (?m) makes ^/$ match at each line, same as flags=re.MULTILINE.
    return re.sub(r'(?m)^ +| +$', '', collapsed)
| 22 | + |
# Prompt templates keyed by experiment name. remove_spaces() lets the
# triple-quoted literal stay readable here while the prompt sent to the
# model is compact. Fix: a stray fullwidth ideographic comma "、" left
# over in the first sentence was being sent to the model verbatim.
template = {
    "Llama2_baseline": remove_spaces("""Based on the closed world assumption, please help me complete a multi-step logical reasoning task (judge true or not). Please help me answer whether the question is correct or not based on the facts and rules formed by these natural language propositions.
    You should just return me one number as the final answer (1 for true and 0 for wrong) and providing reasoning process simply. The Propositions and Questions are as follows: \n""")
}
| 27 | + |
# Load the tokenizer for the chat model and build a half-precision
# text-generation pipeline, letting accelerate place layers across
# the available devices.
tokenizer = AutoTokenizer.from_pretrained(model)
_pipeline_kwargs = dict(
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)
pipeline = transformers.pipeline("text-generation", **_pipeline_kwargs)
| 35 | + |
def extract_string(input_string):
    """Extract the code span from the first 'import' through the last ')'.

    Returns the stripped substring, or None when either boundary is absent.
    """
    start = input_string.find('import')
    if start == -1:
        return None
    end = input_string.rfind(')', start)
    if end == -1:
        return None
    return input_string[start:end + 1].strip()
| 48 | + |
| 49 | + |
def batch_process(text):
    """Generate one sampled completion for *text* and return its full text."""
    completions = pipeline(
        text,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=2048,
    )
    first = completions[0]
    return first['generated_text']
| 60 | + |
| 61 | + |
# All eight evaluation splits: Animal then People, reasoning depths 2-5.
json_files = [
    f"../PARARULE_plus_step{depth}_{domain}_sample.json"
    for domain in ("Animal", "People")
    for depth in range(2, 6)
]
| 73 | + |
# Load the evaluation samples once at module level.
with open(JSON_filename) as sample_file:
    data = json.load(sample_file)
| 76 | + |
| 77 | + |
| 78 | +# # Open the CSV file for writing |
| 79 | +# with open("Llama2-7B-ChatLogic.csv", "w", newline="", encoding="utf-8") as csv_file: |
| 80 | +# csv_writer = csv.writer(csv_file) |
| 81 | +# csv_writer.writerow(["step", "return", "label"]) # Write header |
| 82 | +# |
| 83 | +# for json_file in json_files: |
| 84 | +# step = '_'.join(json_file.split("_")[2:4]) |
| 85 | +# with open(json_file, "r", encoding="utf-8") as f: |
| 86 | +# data = json.load(f) |
| 87 | +# for entry in data: |
| 88 | +# context = entry["context"] |
| 89 | +# question = entry["question"] |
| 90 | +# label = entry["label"] |
| 91 | +# # Replace this with your actual function call |
| 92 | +# responses = batch_process(f"Instructions: ```{template['Llama2_baseline']}```Propositions: ```{context}```\nQuestion: ```{question}```") |
| 93 | +# |
| 94 | +# csv_writer.writerow([step, responses, label]) |
| 95 | + |
correct_num = 0  # number of samples whose executed answer matches the gold label


def _write_and_run(code_text):
    """Write *code_text* to PY_filename and execute it, returning its stdout.

    Raises subprocess.CalledProcessError when the generated script exits
    non-zero (caught by the best-effort handler in the main loop).
    """
    with open(PY_filename, 'w') as script_file:
        script_file.write("{}".format(code_text))
    return subprocess.check_output(['python', PY_filename], universal_newlines=True)


def _repair_until_valid(code_text, output, max_attempts=3):
    """Ask the LLM to repair *code_text* until its output is '0' or '1'.

    Re-runs the rewritten script after each repair round; gives up after
    *max_attempts* rounds and returns the last (code_text, output) pair.
    """
    attempts = 0
    while output.strip() not in ('0', '1'):
        code_text = extract_string(batch_process(f"""{templates.templates['adjustment_agent']}, and here is the generated code: {code_text}, and the error message: {output}"""))
        print("reprocessing...")
        output = _write_and_run(code_text)
        print("New output:" + output)
        attempts += 1
        if attempts == max_attempts:
            break
    return code_text, output


# NOTE(review): range(0, 1) evaluates only the first sample — presumably a
# debugging limit; widen to range(len(data)) for a full run.
for i in range(0, 1):
    try:
        # Step 1: translate the propositions + question into pyDatalog code.
        result_string = extract_string(batch_process(f"""{templates.templates['agent_engineer']}, Here are the propositions: {data[i]['context']} and the Question:{data[i]['question']},
        {templates.templates['no_extra_content']}"""))

        # Step 2: translate the generated code back into natural language.
        propositions_generated = batch_process(f"""{templates.templates["agent_engineer_neg"]}, and the following is the generated code: {result_string}""")

        # Step 3: compare the round-tripped propositions with the originals
        # (zero-shot chain-of-thought), then condense the verdict to 0/1.
        tag = batch_process(f"""{templates.templates['check_error_part1']}, and the original Propositions:{data[i]['context']}, and Question:{data[i]['question']}, the generated Propositions and Questions: {propositions_generated}""")
        tag_final = batch_process(f"""{templates.templates['check_error_part2']}, the following is the analysis processing: {tag}""")

        if "1" in tag_final:
            # Round-trip check passed: execute the generated code as-is.
            output = _write_and_run(result_string)
        else:
            print("enter the regeneration part")
            # Check failed: regenerate the code using the reported differences.
            result_string = extract_string(batch_process(f"""{templates.templates['regeneration']},The original propositions are:{data[i]['context']}, and Question:{data[i]['question']}, and the following is the generated code: {result_string}, and the differences: {tag_final}"""))
            output = _write_and_run(result_string)

        # Self-repair loop shared by both branches (previously duplicated verbatim).
        result_string, output = _repair_until_valid(result_string, output)

        # Bug fix: the original incremented correct_num whenever the output
        # was *invalid* (neither '0' nor '1') and then crashed on int(),
        # silently counting failed runs as correct. Only a valid answer that
        # matches the gold label counts now.
        answer = output.strip()
        if answer in ('0', '1') and int(answer) == data[i]['label']:
            correct_num += 1
    except Exception:
        # Best-effort evaluation: any failure (LLM call, file write,
        # subprocess error) skips this sample rather than aborting the run.
        continue
print(correct_num)
0 commit comments