v1.0.3: Fixed Threeway EPE bug, updated AV2 Eval script to support either eval protocol
kylevedder committed Mar 1, 2024
1 parent 1ce36b5 commit a86047d
Showing 4 changed files with 31 additions and 10 deletions.
26 changes: 17 additions & 9 deletions bucketed_scene_flow_eval/eval/bucketed_epe.py
@@ -165,18 +165,21 @@ def merge_matrix_classes(self, meta_class_lookup: dict[str, list[str]]) -> "Buck

         return merged_matrix

-    def get_mean_average_values(self) -> OverallError:
-        overall_errors = self.get_overall_class_errors()
+    def get_mean_average_values(self, normalized: bool = True) -> OverallError:
+        overall_errors = self.get_overall_class_errors(normalized=normalized)

         average_static_epe = np.nanmean([v.static_epe for v in overall_errors.values()])
         average_dynamic_error = np.nanmean([v.dynamic_error for v in overall_errors.values()])

         return OverallError(average_static_epe, average_dynamic_error)

-    def to_full_latex(self) -> str:
-        error_matrix = self.get_normalized_error_matrix()
+    def to_full_latex(self, normalized: bool = True) -> str:
+        if normalized:
+            error_matrix = self.get_normalized_error_matrix()
+        else:
+            error_matrix = self.epe_storage_matrix.copy()
         # First, get the average class values
-        average_class_values = self.get_overall_class_errors()
+        average_class_values = self.get_overall_class_errors(normalized=normalized)

         # Define the header row with the speed buckets and the beginning of the tabular environment
         column_format = (
@@ -272,7 +275,9 @@ def _build_stat_table(

         return matrix

-    def _save_stats_tables(self, average_stats: dict[BaseSplitKey, BaseSplitValue]):
+    def _save_stats_tables(
+        self, average_stats: dict[BaseSplitKey, BaseSplitValue], normalized: bool = True
+    ):
         super()._save_stats_tables(average_stats)

         # Compute averages over the speed buckets
@@ -295,19 +300,22 @@ def _save_stats_tables(self, average_stats: dict[BaseSplitKey, BaseSplitValue]):
         )

         # Save the raw table
-        save_txt(full_table_save_path, matrix.to_full_latex())
+        save_txt(full_table_save_path, matrix.to_full_latex(normalized=normalized))

         # Save the per-class results
         save_json(
             per_class_save_path,
-            {str(k): str(v) for k, v in matrix.get_overall_class_errors().items()},
+            {
+                str(k): str(v)
+                for k, v in matrix.get_overall_class_errors(normalized=normalized).items()
+            },
             indent=4,
         )

         # Save the mean average results
         save_json(
             mean_average_save_path,
-            matrix.get_mean_average_values().to_tuple(),
+            matrix.get_mean_average_values(normalized=normalized).to_tuple(),
             indent=4,
         )

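The `normalized` flag threaded through these reporting methods is what lets Threeway EPE (next file) opt out of speed normalization while Bucketed EPE keeps its current behavior by default. As a rough, standalone sketch of the distinction — not the library's code, and the divide-by-speed definition of the normalized dynamic error is an assumption rather than something shown in this diff:

    import numpy as np

    def bucket_dynamic_error(epe: np.ndarray, speed: np.ndarray, normalized: bool = True) -> float:
        """Average dynamic error for one speed bucket (illustrative only).

        normalized=True  -> EPE divided by per-point speed (assumed Bucketed EPE convention)
        normalized=False -> raw average EPE in meters (what Threeway EPE should report)
        """
        if normalized:
            return float(np.nanmean(epe / np.maximum(speed, 1e-6)))
        return float(np.nanmean(epe))

Because `normalized` defaults to True everywhere, existing Bucketed EPE callers see no change in their saved tables.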
3 changes: 3 additions & 0 deletions bucketed_scene_flow_eval/eval/threeway_epe.py
@@ -21,6 +21,9 @@ def __init__(
         bucket_edges = [0.0, dynamic_threshold_meters_per_frame, np.inf]
         self.speed_thresholds = list(zip(bucket_edges, bucket_edges[1:]))

+    def _save_stats_tables(self, average_stats):
+        super()._save_stats_tables(average_stats, normalized=False)
+
     def compute_results(
         self, save_results: bool = True, return_distance_threshold: int = 35
     ) -> dict[str, tuple[float, float]]:
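These three added lines are the Threeway EPE bug fix named in the commit title: the subclass now asks the base class to save its tables from raw EPE instead of the speed-normalized values it defaults to. A minimal sketch of the pattern, using placeholder class names since the evaluator classes themselves are not visible in this diff:

    class BucketedEvaluator:  # placeholder name for the evaluator in bucketed_epe.py
        def _save_stats_tables(self, average_stats, normalized: bool = True):
            ...  # writes the LaTeX / JSON tables, normalizing dynamic error when requested

    class ThreewayEvaluator(BucketedEvaluator):  # placeholder name for the Threeway EPE evaluator
        def _save_stats_tables(self, average_stats):
            # Threeway EPE reports plain EPE in meters, so force normalization off.
            super()._save_stats_tables(average_stats, normalized=False)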
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ exclude = [

 [project]
 name = "bucketed_scene_flow_eval"
-version = "1.0.2"
+version = "1.0.3"
 authors = [
     { name="Kyle Vedder", email="[email protected]" },
 ]
10 changes: 10 additions & 0 deletions scripts/evals/av2_eval.py
@@ -82,7 +82,9 @@ def run_eval(
     est_flow_dir: Path,
     output_path: Path,
     cpu_count: int,
+    cache_root: Path,
     every_kth: int = 5,
+    eval_type: str = "bucketed_epe",
 ) -> None:
     assert data_dir.exists(), f"Data directory {data_dir} does not exist."
     assert gt_flow_dir.exists(), f"GT flow directory {gt_flow_dir} does not exist."
@@ -97,7 +99,9 @@ def run_eval(
         with_ground=False,
         with_rgb=False,
         use_gt_flow=True,
+        eval_type=eval_type,
         eval_args=dict(output_path=output_path),
+        cache_root=cache_root,
     )

     est_dataset = Argoverse2SceneFlow(
@@ -107,6 +111,8 @@ def run_eval(
         with_rgb=False,
         use_gt_flow=False,
         use_cache=False,
+        eval_type=eval_type,
+        cache_root=cache_root,
     )

     dataset_evaluator = gt_dataset.evaluator()
@@ -155,6 +161,8 @@ def run_eval(
 parser.add_argument(
     "--every_kth", type=int, default=5, help="Only evaluate every kth scene in a sequence"
 )
+parser.add_argument("--eval_type", type=str, default="bucketed_epe", help="Type of evaluation")
+parser.add_argument("--cache_root", type=Path, default=Path("/tmp/av2_eval_cache/"), help="Path to the cache root directory")

 args = parser.parse_args()

@@ -165,4 +173,6 @@ def run_eval(
     output_path=args.output_path,
     cpu_count=args.cpu_count,
     every_kth=args.every_kth,
+    eval_type=args.eval_type,
+    cache_root=args.cache_root,
 )
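Together, `--eval_type` and `--cache_root` let the same entry point run either eval protocol. A hypothetical direct call to `run_eval` with the new arguments — the paths are placeholders, the `data_dir`/`gt_flow_dir` parameter names are inferred from the assertions above, and the Threeway protocol's registered name is assumed to be "threeway_epe" (only "bucketed_epe" appears in this diff):

    from pathlib import Path

    # Assumes run_eval is imported from (or run inside) scripts/evals/av2_eval.py.
    run_eval(
        data_dir=Path("/data/av2/val"),                    # placeholder paths
        gt_flow_dir=Path("/data/av2/val_sceneflow"),
        est_flow_dir=Path("/results/my_method/val_flow"),
        output_path=Path("/results/my_method/eval"),
        cpu_count=8,
        cache_root=Path("/tmp/av2_eval_cache/"),
        every_kth=5,
        eval_type="bucketed_epe",  # or "threeway_epe" (assumed name) for Threeway EPE
    )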
