-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #79 from hmorimitsu/v04
V04
- Loading branch information
Showing
270 changed files
with
11,551 additions
and
6,064 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
"""Create a side-by-side table comparing the results of PTLFlow with those reported in the original papers. | ||
This script only evaluates results of models that provide the "things" pretrained models. | ||
Tha parsing of this script is tightly connected to how the results are output by validate.py. | ||
""" | ||
|
||
# ============================================================================= | ||
# Copyright 2024 Henrique Morimitsu | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================= | ||
|
||
import argparse | ||
import math | ||
from pathlib import Path | ||
|
||
from loguru import logger | ||
import pandas as pd | ||
|
||
PAPER_VAL_COLS = { | ||
"model": ("Model", "model"), | ||
"sclean": ("S.clean", "sintel-clean-val/epe"), | ||
"sfinal": ("S.final", "sintel-final-val/epe"), | ||
"k15epe": ("K15-epe", "kitti-2015-val/epe"), | ||
"k15fl": ("K15-fl", "kitti-2015-val/flall"), | ||
} | ||
|
||
|
||
def _init_parser() -> argparse.ArgumentParser: | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--paper_results_path", | ||
type=str, | ||
default=str(Path("docs/source/results/paper_results_things.csv")), | ||
help=("Path to the csv file containing the results from the papers."), | ||
) | ||
parser.add_argument( | ||
"--validate_results_path", | ||
type=str, | ||
default=str(Path("docs/source/results/metrics_all_things.csv")), | ||
help=( | ||
"Path to the csv file containing the results obtained by the validate script." | ||
), | ||
) | ||
parser.add_argument( | ||
"--output_dir", | ||
type=str, | ||
default=str(Path("outputs/metrics")), | ||
help=("Path to the directory where the outputs will be saved."), | ||
) | ||
parser.add_argument( | ||
"--add_delta", | ||
action="store_true", | ||
help=( | ||
"If set, adds one more column showing the difference between paper and validation results." | ||
), | ||
) | ||
|
||
return parser | ||
|
||
|
||
def save_results(args: argparse.Namespace) -> None: | ||
paper_df = pd.read_csv(args.paper_results_path) | ||
val_df = pd.read_csv(args.validate_results_path) | ||
paper_df["model"] = paper_df[PAPER_VAL_COLS["model"][0]] | ||
val_df["model"] = val_df[PAPER_VAL_COLS["model"][1]] | ||
df = pd.merge(val_df, paper_df, "left", "model") | ||
|
||
compare_cols = ["ptlflow", "paper"] | ||
if args.add_delta: | ||
compare_cols.append("delta") | ||
|
||
out_dict = {"model": ["", ""]} | ||
for name in list(PAPER_VAL_COLS.keys())[1:]: | ||
for ic, col in enumerate(compare_cols): | ||
out_dict[f"{name}-{col}"] = [name if ic == 0 else "", col] | ||
|
||
for _, row in df.iterrows(): | ||
out_dict["model"].append(row["model"]) | ||
for key in list(PAPER_VAL_COLS.keys())[1:]: | ||
paper_col_name = PAPER_VAL_COLS[key][0] | ||
paper_res = float(row[paper_col_name]) | ||
val_col_name = PAPER_VAL_COLS[key][1] | ||
val_res = float(row[val_col_name]) | ||
res_list = [val_res, paper_res] | ||
|
||
if args.add_delta: | ||
delta = val_res - paper_res | ||
res_list.append(delta) | ||
|
||
for name, res in zip(compare_cols, res_list): | ||
out_dict[f"{key}-{name}"].append( | ||
"" if (math.isinf(res) or math.isnan(res)) else f"{res:.3f}" | ||
) | ||
|
||
out_df = pd.DataFrame(out_dict) | ||
Path(args.output_dir).mkdir(parents=True, exist_ok=True) | ||
output_path = Path(args.output_dir) / "paper_ptlflow_metrics.csv" | ||
out_df.to_csv(output_path, index=False, header=False) | ||
logger.info("Results saved to: {}", output_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = _init_parser() | ||
args = parser.parse_args() | ||
save_results(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# lightning.pytorch==2.4.0 | ||
# Use this config to benchmark all the models. | ||
# python validate.py --config configs/results/model_benchmark_all.yaml | ||
all: true | ||
select: null | ||
ckpt_path: null | ||
exclude: null | ||
csv_path: null | ||
num_trials: 1 | ||
num_samples: 10 | ||
sleep_interval: 0.0 | ||
input_size: | ||
- 500 | ||
- 1000 | ||
output_path: outputs/benchmark | ||
final_speed_mode: median | ||
final_memory_mode: first | ||
plot_axes: null | ||
plot_log_x: false | ||
plot_log_y: false | ||
datatypes: | ||
- fp32 | ||
batch_size: 1 | ||
seed_everything: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# lightning.pytorch==2.4.0 | ||
# Use this config to generate validation results for all models using all their pretrained ckpts. | ||
# python validate.py --config configs/results/validate_all.yaml | ||
all: true | ||
select: null | ||
exclude: null | ||
ckpt_path: things | ||
output_path: outputs/validate | ||
write_outputs: false | ||
show: false | ||
flow_format: original | ||
max_forward_side: null | ||
scale_factor: null | ||
max_show_side: 1000 | ||
max_samples: null | ||
reversed: false | ||
fp16: false | ||
seq_val_mode: all | ||
write_individual_metrics: false | ||
epe_clip: 5.0 | ||
seed_everything: true | ||
data: | ||
predict_dataset: null | ||
test_dataset: null | ||
train_dataset: null | ||
val_dataset: sintel-clean+sintel-final+kitti-2015 | ||
train_batch_size: null | ||
train_num_workers: 4 | ||
train_crop_size: null | ||
train_transform_cuda: false | ||
train_transform_fp16: false | ||
autoflow_root_dir: null | ||
flying_chairs_root_dir: null | ||
flying_chairs2_root_dir: null | ||
flying_things3d_root_dir: null | ||
flying_things3d_subset_root_dir: null | ||
mpi_sintel_root_dir: null | ||
kitti_2012_root_dir: null | ||
kitti_2015_root_dir: null | ||
hd1k_root_dir: null | ||
tartanair_root_dir: null | ||
spring_root_dir: null | ||
kubric_root_dir: null | ||
dataset_config_path: ./datasets.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# lightning.pytorch==2.4.0 | ||
# Use this config to generate validation results for all models using their "things" pretrained ckpt. | ||
# python validate.py --config configs/results/validate_all_things.yaml | ||
all: true | ||
select: null | ||
exclude: null | ||
ckpt_path: things | ||
output_path: outputs/validate | ||
write_outputs: false | ||
show: false | ||
flow_format: original | ||
max_forward_side: null | ||
scale_factor: null | ||
max_show_side: 1000 | ||
max_samples: null | ||
reversed: false | ||
fp16: false | ||
seq_val_mode: all | ||
write_individual_metrics: false | ||
epe_clip: 5.0 | ||
seed_everything: true | ||
data: | ||
predict_dataset: null | ||
test_dataset: null | ||
train_dataset: null | ||
val_dataset: sintel-clean+sintel-final+kitti-2015 | ||
train_batch_size: null | ||
train_num_workers: 4 | ||
train_crop_size: null | ||
train_transform_cuda: false | ||
train_transform_fp16: false | ||
autoflow_root_dir: null | ||
flying_chairs_root_dir: null | ||
flying_chairs2_root_dir: null | ||
flying_things3d_root_dir: null | ||
flying_things3d_subset_root_dir: null | ||
mpi_sintel_root_dir: null | ||
kitti_2012_root_dir: null | ||
kitti_2015_root_dir: null | ||
hd1k_root_dir: null | ||
tartanair_root_dir: null | ||
spring_root_dir: null | ||
kubric_root_dir: null | ||
dataset_config_path: ./datasets.yaml |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
numpydoc==1.6.0 | ||
sphinx==7.2.6 | ||
sphinx_rtd_theme==2.0.0 | ||
numpydoc==1.8.0 | ||
sphinx==8.1.3 | ||
sphinx_rtd_theme==3.0.2 | ||
|
||
ptlflow | ||
timm==0.9.9 | ||
timm==1.0.11 |
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.