From aefefd9ff480729a0780bf7b09d72e19dc6265b0 Mon Sep 17 00:00:00 2001 From: vmcru Date: Tue, 15 Oct 2024 16:28:26 -0400 Subject: [PATCH 1/2] add orion tags to best_hparams.yaml file from original --- benchmarks/MOABB/run_hparam_optimization.sh | 2 + benchmarks/MOABB/utils/rewrite.py | 78 +++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 benchmarks/MOABB/utils/rewrite.py diff --git a/benchmarks/MOABB/run_hparam_optimization.sh b/benchmarks/MOABB/run_hparam_optimization.sh index fd0e209d1..5c9613533 100755 --- a/benchmarks/MOABB/run_hparam_optimization.sh +++ b/benchmarks/MOABB/run_hparam_optimization.sh @@ -450,3 +450,5 @@ scp $best_yaml_file $final_yaml_file echo "The test performance with best hparams is available at $output_folder/best" +# add the orion flags to the best_hparams.yaml file +python utils/rewrite.py $hparams $final_yaml_file \ No newline at end of file diff --git a/benchmarks/MOABB/utils/rewrite.py b/benchmarks/MOABB/utils/rewrite.py new file mode 100644 index 000000000..db60fc648 --- /dev/null +++ b/benchmarks/MOABB/utils/rewrite.py @@ -0,0 +1,78 @@ +#!/usr/bin/python3 +""" +Yaml file rewriter to add orion tags to the best hparams file from original yaml file. +Author +------ +Victor Cruz, 20224 +""" +import argparse +import yaml +import re + +def readargs(): + parser = argparse.ArgumentParser() + parser.add_argument("original_yaml_file", type=str, help="Original yaml file") + parser.add_argument("best_hparams_file", type=str, help="Best hparams file") + args = parser.parse_args() + + # Check if the file paths are valid + if not args.original_yaml_file.endswith(".yaml"): + raise ValueError("Original yaml file must be a yaml file") + if not args.best_hparams_file.endswith(".yaml"): + raise ValueError("Best hparams file must be a yaml file") + return args + +def extract_orion_tags(original_yaml_file): + """ + Function to extract orion tags and variable names from the original yaml file. + Orion tags are comments that start with '# @orion_step'. + """ + orion_tags = {} + tag_pattern = re.compile(r"# @orion_step(\d+):\s*(.*)") + + with open(original_yaml_file, "r") as og_f: + for line in og_f: + # Extract lines that contain Orion tags + tag_match = tag_pattern.search(line.strip()) + if tag_match: + variable_name = line.split(":")[0].strip() # Get the variable name before ":" + tag_info = tag_match.group(0) # Full tag line + orion_tags[variable_name] = tag_info # Store variable and tag info + return orion_tags + +def rewrite_with_orion_tags(original_yaml_file, best_hparams_file): + """ + Function to add orion tags to the best hparams file. + Matches based on the variable name from the original file to the target file. + """ + orion_tags = extract_orion_tags(original_yaml_file) + + # Read the best_hparams YAML file + with open(best_hparams_file, "r") as best_f: + best_hparams_lines = best_f.readlines() + + # Add orion tags to the appropriate lines in the new file + new_best_hparams_lines = [] + for line in best_hparams_lines: + stripped_line = line.strip() + # Extract variable name from the line in the best hparams file + if ":" in stripped_line: + variable_name = stripped_line.split(":")[0].strip() + + # Check if this variable has a corresponding orion tag + if variable_name in orion_tags: + # Append the orion tag to the same line, ensuring there's a space before the comment + line = line.rstrip() + " " + orion_tags[variable_name] + "\n" + new_best_hparams_lines.append(line) + else: + new_best_hparams_lines.append(line) + else: + new_best_hparams_lines.append(line) + + # Write the modified content back to the best_hparams file + with open(best_hparams_file, "w") as best_f: + best_f.writelines(new_best_hparams_lines) + +if __name__ == "__main__": + args = readargs() + rewrite_with_orion_tags(args.original_yaml_file, args.best_hparams_file) From 6349f6301f73454a0ca8edefcdf01f154aef6cfa Mon Sep 17 00:00:00 2001 From: vmcru Date: Tue, 15 Oct 2024 16:33:34 -0400 Subject: [PATCH 2/2] tytpig mistake --- benchmarks/MOABB/utils/rewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/MOABB/utils/rewrite.py b/benchmarks/MOABB/utils/rewrite.py index db60fc648..5378837c7 100644 --- a/benchmarks/MOABB/utils/rewrite.py +++ b/benchmarks/MOABB/utils/rewrite.py @@ -3,7 +3,7 @@ Yaml file rewriter to add orion tags to the best hparams file from original yaml file. Author ------ -Victor Cruz, 20224 +Victor Cruz, 2024 """ import argparse import yaml