diff --git a/scripts/evals/setup_sparse_user_submission.py b/scripts/evals/setup_sparse_user_submission.py index 34afe72..2bea4c6 100644 --- a/scripts/evals/setup_sparse_user_submission.py +++ b/scripts/evals/setup_sparse_user_submission.py @@ -1,5 +1,6 @@ import argparse import math +import shutil import time import zipfile from pathlib import Path @@ -7,9 +8,25 @@ import tqdm -def _unzip_submission(working_dir: Path, submission_zip: Path) -> Path: +def _unzip_submission(working_dir: Path, submission_zip: Path, every_kth: int) -> Path: assert submission_zip.exists(), f"Submission zip {submission_zip} does not exist." submission_dir = working_dir / "submission" + + submission_dir.mkdir(parents=True, exist_ok=False) + + # If the submission zip is actually a directory, symlink it to the submission dir + if submission_zip.is_dir(): + # Iterate over every sequence folder + for sequence_folder in submission_zip.glob("*"): + sequence_folder_name = sequence_folder.name + user_sequence_folder = submission_dir / sequence_folder_name + user_sequence_folder.mkdir(parents=True, exist_ok=False) + for idx, user_file in enumerate(sorted(sequence_folder.glob("*.feather"))): + if idx % every_kth == 0: + shutil.copy(user_file, user_sequence_folder / user_file.name) + + return submission_dir + print(f"Unzipping {submission_zip} to {submission_dir}") before_unzip = time.time() with zipfile.ZipFile(submission_zip, "r") as zip_ref: @@ -73,7 +90,7 @@ def run_setup_sparse_user_submission( working_dir: Path, user_submission_zip: Path, ground_truth_root_folder: Path, - every_kth_entry: int = 5, + every_kth_entry: int, ) -> Path: working_dir = Path(working_dir) user_submission_zip = Path(user_submission_zip) @@ -87,7 +104,7 @@ def run_setup_sparse_user_submission( ground_truth_root_folder.exists() ), f"Ground truth root folder {ground_truth_root_folder} does not exist." - unziped_submission_dir = _unzip_submission(working_dir, user_submission_zip) + unziped_submission_dir = _unzip_submission(working_dir, user_submission_zip, every_kth_entry) # Iterate over the sequence folders and validate and create dummy entries for gt_sequence_folder in tqdm.tqdm(sorted(ground_truth_root_folder.glob("*"))): @@ -119,8 +136,17 @@ def run_setup_sparse_user_submission( type=Path, help="The root folder containing the ground truth sequence folders.", ) + parser.add_argument( + "--every_kth_entry", + type=int, + default=5, + help="The number of entries to skip in the user submission.", + ) args = parser.parse_args() run_setup_sparse_user_submission( - args.working_dir, args.user_submission_zip, args.ground_truth_root_folder + args.working_dir, + args.user_submission_zip, + args.ground_truth_root_folder, + args.every_kth_entry, )