Skip to content

Commit

Permalink
update: refine parsing script
Browse files Browse the repository at this point in the history
  • Loading branch information
terryyz committed Apr 23, 2024
1 parent 52b80bd commit 4ff1c87
Showing 1 changed file with 63 additions and 24 deletions.
87 changes: 63 additions & 24 deletions script/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def extract_test(file_contents, function_name):
return f"Error processing the script: {e}"


def extract_content(file_path):

def extract_content(file_path, rename_id=None):
data = {"file": file_path.split("/")[-1]}
with open(file_path, 'r') as file:
for line in file:
Expand All @@ -248,14 +249,29 @@ def extract_content(file_path):
data["task_id"] = function_name
break
with open(file_path, "r", encoding="utf-8") as f:
if not rename_id:
rename_id = data["task_id"]
data["task_id"] = rename_id
content = f.read().strip("\n").replace("AxesSubplot", "Axes").replace("matplotlib.axes._subplots", "matplotlib.axes._axes")
content = content.replace(function_name, rename_id)

function_name = rename_id
# Extracting the docstring if present
docstring_start = content.find('"""')
docstring_end = content.find('"""', docstring_start + 3)
if docstring_start == -1 and docstring_end == -1:
dq_docstring_start = content.find('"""')
dq_docstring_end = content.find('"""', dq_docstring_start + 3)
sq_docstring_start = content.find("'''")
if (dq_docstring_start > sq_docstring_start and sq_docstring_start != -1) or dq_docstring_end == -1 or dq_docstring_start == -1:
docstring_start = content.find("'''")
docstring_end = content.find("'''", docstring_start + 3)
else:
docstring_start = dq_docstring_start
docstring_end = dq_docstring_end
# get the nearest "def" before docstring_start
function_name_start = content.rfind("def", 0, docstring_start)
data["signature"] = " ".join(l.strip() for l in content[function_name_start:docstring_start].strip().splitlines())
data["prompt"] = content[:docstring_end + 3]
data["prompt_wo_doc"] = "\n".join(line for line in content[:docstring_start].strip().splitlines() if line)
# print(data["prompt"])
tree = ast.parse(content)
function_end_line = None
for node in ast.walk(tree):
Expand All @@ -273,13 +289,13 @@ def extract_content(file_path):
data["apis"] = extract_apis(data["prompt"] + "\n" + data["canonical_solution"])
data["libs"] = list(set([api.split(".")[0] for api in data["apis"]]))
_, unused_imports = filter_unused_imports(data["prompt"], data["libs"])
if unused_imports:
print(f"Unused imports in {file_path.replace('clean/','raw/')}: {unused_imports}")
# data["test"] = "\n".join(unused_imports) + "\n" + data["test"]
# if unused_imports:
# print(f"Unused imports in {file_path.replace('clean/','raw/')}: {unused_imports}")
docs = re.search(r'\"\"\"(.*?)\"\"\"', data["prompt"], re.DOTALL)
if not docs:
docs = re.search(r"'''(.*?)'''", data["prompt"], re.DOTALL)
data['doc'] = parse_docstring(docs.group(1))
data["doc"] = parse_docstring(docs.group(1))
data["instruction"] = get_instruction_prompt(data)
return data

def count_return_values(function_code):
Expand All @@ -306,12 +322,12 @@ def find_returns(node):
def parse_docstring(docstring):
sections = {
'description': [],
'note': [],
'notes': [],
'params': [],
'returns': [],
'reqs': [],
'raises': [],
'example': []
'examples': []
}
# Split the docstring into lines and strip whitespace
lines = [line.strip() for line in docstring.strip().split('\n')]
Expand All @@ -320,8 +336,9 @@ def parse_docstring(docstring):
replace_word = ""

for line in lines:
if line.startswith('Note:'):
current_section = 'note'
line = line.strip()
if line.startswith('Note:') or line.startswith('Notes:'):
current_section = 'notes'
replace_word = 'Note:'
elif line.startswith('Parameters:'):
current_section = 'params'
Expand All @@ -336,7 +353,7 @@ def parse_docstring(docstring):
current_section = 'raises'
replace_word = 'Raises:'
elif line.startswith('Example:') or line.startswith('Examples:'):
current_section = 'example'
current_section = 'examples'
replace_word = 'Example:'
elif line and not current_section:
current_section = 'description'
Expand All @@ -346,7 +363,7 @@ def parse_docstring(docstring):
if current_section and current_section != 'description':
reformat_line = line.replace(replace_word, '')
if reformat_line:
if current_section != 'example':
if current_section != 'examples':
reformat_line = reformat_line.strip('- ')
sections[current_section].append(reformat_line)
elif current_section and current_section == 'description':
Expand All @@ -361,6 +378,16 @@ def parse_docstring(docstring):
def reconstruct_problem(data):
    """Reassemble the full problem source from its extracted pieces.

    Args:
        data: Mapping with string values under the keys ``"prompt"``,
            ``"canonical_solution"`` and ``"test"``.

    Returns:
        The prompt, the canonical solution and the test code concatenated:
        one newline after the prompt, a blank line before the tests, and a
        trailing newline at the end.
    """
    pieces = [
        data["prompt"],
        "\n",
        data["canonical_solution"],
        "\n\n",
        data["test"],
        "\n",
    ]
    return "".join(pieces)

def get_instruction_prompt(data):
    """Build the natural-language instruction prompt for one task.

    Args:
        data: Mapping that provides
            ``"signature"`` (str), ``"prompt_wo_doc"`` (str) and ``"doc"``,
            a mapping of docstring sections where each section is a list of
            strings. ``"description"`` and ``"returns"`` are required;
            ``"notes"`` and ``"raises"`` are optional and skipped when
            missing or empty.

    Returns:
        The instruction text: the task description, optional note and
        exception lines, the expected outputs, and the docstring-free
        prompt inside a fenced code block.
    """
    doc = data["doc"]
    parts = [
        "Write a function called ",
        f'`{data["signature"]}` to: ',
        " ".join(doc["description"]),
    ]

    # Optional sections: only mentioned when they actually carry content.
    notes = doc.get("notes")
    if notes:
        parts.append("\nNote that: " + " ".join(notes))
    raises = doc.get("raises")
    if raises:
        parts.append(
            "\nThe function should raise the exception for: " + " ".join(raises)
        )

    parts.append("\nThe function should output with:\n ")
    parts.append("\n ".join(doc["returns"]))
    parts.append("\nYou should start with:\n```\n")
    parts.append(data["prompt_wo_doc"])
    parts.append("\n```")
    return "".join(parts)

def check_test_wo_doc(data):
"Check if the problem is related to file system, network requests and database"

Expand All @@ -385,27 +412,39 @@ def validate_lib_num(data):
return True

def validate_doc_example(data):
    """Report whether the parsed docstring contains an examples section.

    Args:
        data: Mapping whose ``"doc"`` entry is the parsed docstring, a
            mapping of section name to list of lines.

    Returns:
        ``True`` if ``data["doc"]["examples"]`` is non-empty, else ``False``.
    """
    # The section key is "examples" (plural); the older singular "example"
    # key is no longer used by the parsed docstring sections.
    return bool(data["doc"]["examples"])

def validate_doc_returns(data):
    """Report whether the parsed docstring contains a returns section.

    Args:
        data: Mapping whose ``"doc"`` entry is the parsed docstring, a
            mapping of section name to list of lines.

    Returns:
        ``True`` if ``data["doc"]["returns"]`` is non-empty, else ``False``.
    """
    return bool(data["doc"]["returns"])

def validate_doc_reqs(data):
    """Report whether the parsed docstring contains a requirements section.

    Args:
        data: Mapping whose ``"doc"`` entry is the parsed docstring, a
            mapping of section name to list of lines.

    Returns:
        ``True`` if ``data["doc"]["reqs"]`` is non-empty, else ``False``.
    """
    return bool(data["doc"]["reqs"])



if __name__ == "__main__":
shutil.rmtree("data/processed", ignore_errors=True)
os.makedirs("data/processed")
os.makedirs("data/processed", exist_ok=True)
with open("data/open-eval.jsonl", "w") as f:
for file in tqdm(glob("data/clean/*.py")):
if "ming" in file:
continue
data = extract_content(file)
# print(data["apis"])
# assert validate_lib_num(data), f"Less than 2 libraries are used in {file.replace('clean/', 'raw/')}"
# assert validate_doc_example(data), f"Example is missing in {file.replace('clean/', 'raw/')}"
for i, file in enumerate(tqdm(glob("data/clean/*.py"))):

data = extract_content(file, None)
if not validate_lib_num(data):
print(file.replace('clean/', 'raw/'), "Less than 2 libraries are used")
if not validate_doc_example(data):
print(file.replace('clean/', 'raw/'), "Example is missing")
if not validate_doc_returns(data):
print(file.replace('clean/', 'raw/'), "Returns is missing")
if not validate_doc_reqs(data):
print(file.replace('clean/', 'raw/'), "Requirements is missing")
f.write(json.dumps(data) + "\n")
file_name = file.split("/")[-1].split(".")[0]
file_name = file_name + "_wo_doc" if check_test_wo_doc(data) else file_name + "_w_doc"
with open(f"data/processed/{file_name}.py", "w") as f2:
f2.write(reconstruct_problem(data))
f2.write(reconstruct_problem(data))

0 comments on commit 4ff1c87

Please sign in to comment.