Skip to content

Commit

Permalink
Added verbose flag
Browse files Browse the repository at this point in the history
  • Loading branch information
kylevedder committed Apr 27, 2024
1 parent 7b9e48a commit a8f377e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ def _sanitize_and_validate_inputs(
), f"predictions and ground_truth must have the same length, got {len(predicted_flow.full_flow)} and {len(gt_frame.flow.full_flow)}"

# Validate that all valid gt flow vectors are considered valid in the predictions.
assert np.all(
(predicted_flow.mask & gt_frame.flow.mask) == gt_frame.flow.mask
), "All valid gt flow vectors must be considered valid in the predictions"
if not np.all((predicted_flow.mask & gt_frame.flow.mask) == gt_frame.flow.mask):
print(
f"{gt_frame.log_id} index {gt_frame.log_idx} with timestamp {gt_frame.log_timestamp} missing {np.sum(gt_frame.flow.mask & ~predicted_flow.mask)} points marked valid."
)

# Set the prediction valid flow mask to be the gt flow so everything lines up
predicted_flow.mask = gt_frame.flow.mask
Expand Down
20 changes: 17 additions & 3 deletions scripts/evals/av2_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@ def _work(
gt_dataset: Argoverse2CausalSceneFlow,
est_dataset: Argoverse2CausalSceneFlow,
evaluator: Evaluator,
verbose: bool = True,
) -> Evaluator:
# Set tqdm bar on the row of the terminal corresponding to the shard index
for idx in tqdm.tqdm(shard_list, position=shard_idx + 1, desc=f"Shard {shard_idx}"):
iterator = shard_list
if verbose:
iterator = tqdm.tqdm(shard_list, position=shard_idx + 1, desc=f"Shard {shard_idx}")

for idx in iterator:
gt_lst = gt_dataset[idx]
est_lst = est_dataset[idx]
assert len(gt_lst) == len(est_lst) == 2, f"GT and estimated lists must have length 2."
Expand All @@ -71,7 +76,9 @@ def _work(


def _work_wrapper(
args: tuple[int, list[int], Argoverse2CausalSceneFlow, Argoverse2CausalSceneFlow, Evaluator]
args: tuple[
int, list[int], Argoverse2CausalSceneFlow, Argoverse2CausalSceneFlow, Evaluator, bool
]
) -> Evaluator:
return _work(*args)

Expand All @@ -85,6 +92,7 @@ def run_eval(
cache_root: Path,
every_kth: int = 5,
eval_type: str = "bucketed_epe",
verbose: bool = True,
) -> None:
assert data_dir.exists(), f"Data directory {data_dir} does not exist."
assert gt_flow_dir.exists(), f"GT flow directory {gt_flow_dir} does not exist."
Expand Down Expand Up @@ -124,7 +132,7 @@ def run_eval(
# Shard the dataset into pieces for each CPU
shard_lists = _make_index_shards(gt_dataset, cpu_count, every_kth)
args_list = [
(shard_idx, shard_list, gt_dataset, est_dataset, dataset_evaluator)
(shard_idx, shard_list, gt_dataset, est_dataset, dataset_evaluator, verbose)
for shard_idx, shard_list in enumerate(shard_lists)
]

Expand Down Expand Up @@ -168,6 +176,11 @@ def run_eval(
default=Path("/tmp/av2_eval_cache/"),
help="Path to the cache root directory",
)
parser.add_argument(
"--quiet",
action="store_true",
help="Suppress output",
)

args = parser.parse_args()

Expand All @@ -180,4 +193,5 @@ def run_eval(
every_kth=args.every_kth,
eval_type=args.eval_type,
cache_root=args.cache_root,
verbose=not args.quiet,
)

0 comments on commit a8f377e

Please sign in to comment.