Commit
Merge pull request #79 from hmorimitsu/v04
V04
hmorimitsu authored Dec 4, 2024
2 parents 9f20eec + be833e9 commit 069b237
Showing 270 changed files with 11,551 additions and 6,064 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/build.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.10']
python-version: ['3.12']

steps:
- uses: actions/checkout@v4
@@ -29,10 +29,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install build==1.0.3
python -m pip install --upgrade setuptools==68.0.0 wheel
python -m pip install --upgrade pytest
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu
python -m pip install build==1.2.2.post1
python -m pip install --upgrade setuptools==75.6.0 wheel==0.45.1
python -m pip install --upgrade pytest==8.3.3
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu
- name: Install package and remove local dir
run: |
python -m build
10 changes: 5 additions & 5 deletions .github/workflows/lightning.yml
@@ -16,23 +16,23 @@ jobs:
strategy:
fail-fast: false
matrix:
lightning: [1.9.5]
lightning: ["2.1.4", "2.2.5", "2.3.3", "2.4.0"]

steps:
- uses: actions/checkout@v4
- name: Replace lightning
uses: jacobtomlinson/gha-find-replace@v3
with:
find: "lightning<2"
replace: "lightning==${{ matrix.lightning }}"
find: "lightning[pytorch-extra]>=2,<2.5"
replace: "lightning[pytorch-extra]==${{ matrix.lightning }}"
regex: false
include: "requirements.txt"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Test with pytest
run: |
pip install pytest
pip install pytest==8.3.3
python -m pytest tests/
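
A rough Python equivalent of the "Replace lightning" step above, as a sketch only: the workflow itself uses the jacobtomlinson/gha-find-replace action, and the pin strings and example version below are copied from the workflow diff.

```python
# Sketch: rewrite the lightning pin in requirements.txt before installing,
# mirroring the find/replace strings used by the CI step above.
from pathlib import Path

matrix_lightning = "2.4.0"  # one of the versions in the matrix above

req = Path("requirements.txt")
text = req.read_text()
text = text.replace(
    "lightning[pytorch-extra]>=2,<2.5",
    f"lightning[pytorch-extra]=={matrix_lightning}",
)
req.write_text(text)
```
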
4 changes: 2 additions & 2 deletions .github/workflows/publish_pypi.yml
@@ -26,8 +26,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu
pip install build==1.0.3
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu
pip install build==1.2.2.post1
- name: Build package
run: python -m build
- name: Publish package
6 changes: 3 additions & 3 deletions .github/workflows/python.yml
@@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v4
@@ -27,9 +27,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Test with pytest
run: |
pip install pytest
pip install pytest==8.3.3
python -m pytest tests/
9 changes: 5 additions & 4 deletions .github/workflows/pytorch.yml
@@ -17,9 +17,10 @@ jobs:
fail-fast: false
matrix:
pytorch: [
'torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu',
'torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cpu',
'torch==1.13.1+cpu torchvision==0.14.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu',
'torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cpu',
'torch==2.4.1 torchvision==0.19.1 --index-url https://download.pytorch.org/whl/cpu',
'torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cpu',
'torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cpu',
]

steps:
@@ -31,5 +32,5 @@ jobs:
pip install -r requirements.txt
- name: Test with pytest
run: |
pip install pytest
pip install pytest==8.3.3
python -m pytest tests/
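
Each torch pin in the CI workflows above comes paired with a matching torchvision release. The sketch below is not part of the repository: it hard-codes the pairs from this diff's matrices and checks that a local environment uses one of them. The mapping is an assumption derived from the diff, not an official compatibility table.

```python
# Sanity-check sketch: verify that the installed torch/torchvision pair matches
# one of the combinations exercised by the CI matrices above.
import torch
import torchvision


def base_version(version: str) -> str:
    """Strip local build tags such as '+cpu' or '+cu121'."""
    return version.split("+")[0]


CI_PAIRS = {
    "2.5.1": "0.20.1",
    "2.4.1": "0.19.1",
    "2.3.1": "0.18.1",
    "2.2.2": "0.17.2",
}

torch_v = base_version(torch.__version__)
tv_v = base_version(torchvision.__version__)
expected = CI_PAIRS.get(torch_v)

if expected is None:
    print(f"torch {torch_v} is not covered by the CI matrix")
elif expected != tv_v:
    print(f"torchvision {tv_v} differs from the CI pin {expected} for torch {torch_v}")
else:
    print(f"torch {torch_v} / torchvision {tv_v} matches the CI matrix")
```
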
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,6 +1,8 @@
lightning_logs/
ptlflow_logs/
ptlflow_scripts/
outputs/
ckpts/

# Byte-compiled / optimized / DLL files
__pycache__/
23 changes: 21 additions & 2 deletions README.md
@@ -24,6 +24,20 @@ This is still under development, so some things may not work as intended. I plan

## What's new

### - v0.4.0

Major update to support Lightning 2 (finally!). However, it also introduces breaking changes from the previous v0.3 code. See the details below.

- Transitioning from v0.3 to v0.4: check the [v0.4 upgrade guide](https://ptlflow.readthedocs.io/en/latest/starting/v04_upgrade_guide.html)
- Added features:
  - Support for YAML config files. See the [config file documentation](https://ptlflow.readthedocs.io/en/latest/starting/config_files.html)
  - A table [comparing PTLFlow results with the original papers](https://ptlflow.readthedocs.io/en/latest/results/paper_ptlflow.html) to check the stability of the included models.
- Added new models:
  - NeuFlow v2 [https://arxiv.org/abs/2408.10161](https://arxiv.org/abs/2408.10161)
- Added support for more datasets:
  - Middlebury-ST [https://vision.middlebury.edu/stereo/data/scenes2014/](https://vision.middlebury.edu/stereo/data/scenes2014/)
  - VIPER [https://playing-for-benchmarks.org/](https://playing-for-benchmarks.org/)

### - v0.3.2

- Added new models:
@@ -129,8 +143,13 @@ Please take a look at the [documentation](https://ptlflow.readthedocs.io/) to le

You can also check the notebooks below running on Google Colab for some practical examples:

- [Inference with a pretrained model](https://colab.research.google.com/drive/1YARBRUGplqTRnRuY9sKIs6LY_2kWAWZJ?usp=sharing).
- [Training and using the learned weights for inference](https://colab.research.google.com/drive/1mbuAEF728_jZpFEsQHXDxjIGAcB1-nVs?usp=sharing).
- [Inference with a pretrained model](https://colab.research.google.com/drive/1_WXvIRweQJgex0X-HS0LFXBb0IWZIvR4?usp=sharing).
- [Training and using the learned weights for inference](https://colab.research.google.com/drive/1b_SMGSXh9F9TkinqZt0c64EH-GE87HVi?usp=sharing).

If you are using the previous v0.3.X code, then check the [v0.3.2 documentation](https://ptlflow.readthedocs.io/en/v0.3.2/) and the following example notebooks:

- [Inference with a pretrained model (PTLFlow v0.3)](https://colab.research.google.com/drive/1YARBRUGplqTRnRuY9sKIs6LY_2kWAWZJ?usp=sharing).
- [Training and using the learned weights for inference (PTLFlow v0.3)](https://colab.research.google.com/drive/1mbuAEF728_jZpFEsQHXDxjIGAcB1-nVs?usp=sharing).

## Licenses

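Stepping outside the diff for a moment: the notebooks linked in the README section above cover inference with a pretrained model. The following is a minimal sketch of that workflow, assuming the dict-based model interface described in the PTLFlow documentation; the `raft_small` name, the `ckpt_path="things"` keyword, and the `"flows"` output key are assumptions to verify against the v0.4 upgrade guide and the notebooks.

```python
# Minimal inference sketch (assumptions flagged above). PTLFlow models are
# LightningModules that take a dict with an "images" tensor of shape
# [batch, num_images, 3, height, width] and return a dict of predictions.
import torch
import ptlflow

# Load a small model; "things" is assumed to select the FlyingThings3D checkpoint.
model = ptlflow.get_model("raft_small", ckpt_path="things")
model.eval()

# Two random frames in [0, 1]; a real pipeline would load and pad/resize actual
# images, as shown in the official notebooks.
images = torch.rand(1, 2, 3, 384, 512)

with torch.no_grad():
    predictions = model({"images": images})

flow = predictions["flows"][:, 0]  # assumed shape: [batch, 2, height, width]
print(flow.shape)
```
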
117 changes: 117 additions & 0 deletions compare_paper_results.py
@@ -0,0 +1,117 @@
"""Create a side-by-side table comparing the results of PTLFlow with those reported in the original papers.
This script only evaluates models that provide a "things" pretrained checkpoint.
The parsing in this script is tightly coupled to how the results are output by validate.py.
"""

# =============================================================================
# Copyright 2024 Henrique Morimitsu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse
import math
from pathlib import Path

from loguru import logger
import pandas as pd

PAPER_VAL_COLS = {
"model": ("Model", "model"),
"sclean": ("S.clean", "sintel-clean-val/epe"),
"sfinal": ("S.final", "sintel-final-val/epe"),
"k15epe": ("K15-epe", "kitti-2015-val/epe"),
"k15fl": ("K15-fl", "kitti-2015-val/flall"),
}


def _init_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--paper_results_path",
type=str,
default=str(Path("docs/source/results/paper_results_things.csv")),
help=("Path to the csv file containing the results from the papers."),
)
parser.add_argument(
"--validate_results_path",
type=str,
default=str(Path("docs/source/results/metrics_all_things.csv")),
help=(
"Path to the csv file containing the results obtained by the validate script."
),
)
parser.add_argument(
"--output_dir",
type=str,
default=str(Path("outputs/metrics")),
help=("Path to the directory where the outputs will be saved."),
)
parser.add_argument(
"--add_delta",
action="store_true",
help=(
"If set, adds one more column showing the difference between paper and validation results."
),
)

return parser


def save_results(args: argparse.Namespace) -> None:
paper_df = pd.read_csv(args.paper_results_path)
val_df = pd.read_csv(args.validate_results_path)
paper_df["model"] = paper_df[PAPER_VAL_COLS["model"][0]]
val_df["model"] = val_df[PAPER_VAL_COLS["model"][1]]
df = pd.merge(val_df, paper_df, "left", "model")

compare_cols = ["ptlflow", "paper"]
if args.add_delta:
compare_cols.append("delta")

out_dict = {"model": ["", ""]}
for name in list(PAPER_VAL_COLS.keys())[1:]:
for ic, col in enumerate(compare_cols):
out_dict[f"{name}-{col}"] = [name if ic == 0 else "", col]

for _, row in df.iterrows():
out_dict["model"].append(row["model"])
for key in list(PAPER_VAL_COLS.keys())[1:]:
paper_col_name = PAPER_VAL_COLS[key][0]
paper_res = float(row[paper_col_name])
val_col_name = PAPER_VAL_COLS[key][1]
val_res = float(row[val_col_name])
res_list = [val_res, paper_res]

if args.add_delta:
delta = val_res - paper_res
res_list.append(delta)

for name, res in zip(compare_cols, res_list):
out_dict[f"{key}-{name}"].append(
"" if (math.isinf(res) or math.isnan(res)) else f"{res:.3f}"
)

out_df = pd.DataFrame(out_dict)
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
output_path = Path(args.output_dir) / "paper_ptlflow_metrics.csv"
out_df.to_csv(output_path, index=False, header=False)
logger.info("Results saved to: {}", output_path)


if __name__ == "__main__":
parser = _init_parser()
args = parser.parse_args()
save_results(args)
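
For readers skimming compare_paper_results.py above: the core of save_results() is a left join of the validation CSV onto the paper CSV, keyed on the model name, plus an optional delta column. Below is a toy, self-contained illustration of that idea; the numbers are placeholders, not real benchmark results.

```python
# Toy illustration of the merge performed by save_results() above. The values
# are placeholders for illustration only, not actual benchmark results.
import pandas as pd

paper_df = pd.DataFrame(
    {"Model": ["raft", "gma"], "S.clean": [1.4, 1.3], "K15-fl": [5.0, 4.7]}
)
val_df = pd.DataFrame(
    {
        "model": ["raft", "gma"],
        "sintel-clean-val/epe": [1.5, 1.35],
        "kitti-2015-val/flall": [5.2, 4.9],
    }
)

# Same alignment trick as the script: copy the paper's "Model" column into a
# lowercase "model" key, then left-join the validation results onto it.
paper_df["model"] = paper_df["Model"]
merged = pd.merge(val_df, paper_df, "left", "model")

# Optional delta, as enabled by --add_delta: validation result minus paper result.
merged["sclean-delta"] = merged["sintel-clean-val/epe"] - merged["S.clean"]
print(merged[["model", "sintel-clean-val/epe", "S.clean", "sclean-delta"]])
```
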
24 changes: 24 additions & 0 deletions configs/results/model_benchmark_all.yaml
@@ -0,0 +1,24 @@
# lightning.pytorch==2.4.0
# Use this config to benchmark all the models.
# python validate.py --config configs/results/model_benchmark_all.yaml
all: true
select: null
ckpt_path: null
exclude: null
csv_path: null
num_trials: 1
num_samples: 10
sleep_interval: 0.0
input_size:
- 500
- 1000
output_path: outputs/benchmark
final_speed_mode: median
final_memory_mode: first
plot_axes: null
plot_log_x: false
plot_log_y: false
datatypes:
- fp32
batch_size: 1
seed_everything: true
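
The header comments mark these as lightning.pytorch CLI-style configs, so they can also be loaded and tweaked programmatically before a run. A minimal sketch, assuming only PyYAML and the keys visible above; the override values are arbitrary examples.

```python
# Load the benchmark config shown above, tweak a couple of fields, and save a
# variant next to it. Only PyYAML and keys visible in this diff are assumed.
from pathlib import Path

import yaml

src = Path("configs/results/model_benchmark_all.yaml")
cfg = yaml.safe_load(src.read_text())

# Example overrides (arbitrary values): also benchmark fp16 and average the
# timing over more samples.
cfg["datatypes"] = ["fp32", "fp16"]
cfg["num_samples"] = 20

dst = src.with_name("model_benchmark_fp16.yaml")
dst.write_text(yaml.safe_dump(cfg, sort_keys=False))
print(f"Wrote {dst}")
```
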
44 changes: 44 additions & 0 deletions configs/results/validate_all.yaml
@@ -0,0 +1,44 @@
# lightning.pytorch==2.4.0
# Use this config to generate validation results for all models using all their pretrained ckpts.
# python validate.py --config configs/results/validate_all.yaml
all: true
select: null
exclude: null
ckpt_path: things
output_path: outputs/validate
write_outputs: false
show: false
flow_format: original
max_forward_side: null
scale_factor: null
max_show_side: 1000
max_samples: null
reversed: false
fp16: false
seq_val_mode: all
write_individual_metrics: false
epe_clip: 5.0
seed_everything: true
data:
predict_dataset: null
test_dataset: null
train_dataset: null
val_dataset: sintel-clean+sintel-final+kitti-2015
train_batch_size: null
train_num_workers: 4
train_crop_size: null
train_transform_cuda: false
train_transform_fp16: false
autoflow_root_dir: null
flying_chairs_root_dir: null
flying_chairs2_root_dir: null
flying_things3d_root_dir: null
flying_things3d_subset_root_dir: null
mpi_sintel_root_dir: null
kitti_2012_root_dir: null
kitti_2015_root_dir: null
hd1k_root_dir: null
tartanair_root_dir: null
spring_root_dir: null
kubric_root_dir: null
dataset_config_path: ./datasets.yaml
44 changes: 44 additions & 0 deletions configs/results/validate_all_things.yaml
@@ -0,0 +1,44 @@
# lightning.pytorch==2.4.0
# Use this config to generate validation results for all models using their "things" pretrained ckpt.
# python validate.py --config configs/results/validate_all_things.yaml
all: true
select: null
exclude: null
ckpt_path: things
output_path: outputs/validate
write_outputs: false
show: false
flow_format: original
max_forward_side: null
scale_factor: null
max_show_side: 1000
max_samples: null
reversed: false
fp16: false
seq_val_mode: all
write_individual_metrics: false
epe_clip: 5.0
seed_everything: true
data:
predict_dataset: null
test_dataset: null
train_dataset: null
val_dataset: sintel-clean+sintel-final+kitti-2015
train_batch_size: null
train_num_workers: 4
train_crop_size: null
train_transform_cuda: false
train_transform_fp16: false
autoflow_root_dir: null
flying_chairs_root_dir: null
flying_chairs2_root_dir: null
flying_things3d_root_dir: null
flying_things3d_subset_root_dir: null
mpi_sintel_root_dir: null
kitti_2012_root_dir: null
kitti_2015_root_dir: null
hd1k_root_dir: null
tartanair_root_dir: null
spring_root_dir: null
kubric_root_dir: null
dataset_config_path: ./datasets.yaml
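
One convention worth noting in these validation configs: val_dataset chains several splits with a "+" separator. A trivial sketch of splitting such a value into individual dataset specifiers, assuming "+" is used purely as a separator, as the value above suggests:

```python
# Split the "+"-separated val_dataset value into individual dataset names.
val_dataset = "sintel-clean+sintel-final+kitti-2015"
dataset_names = val_dataset.split("+")
print(dataset_names)  # ['sintel-clean', 'sintel-final', 'kitti-2015']
```
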
File renamed without changes.
8 changes: 4 additions & 4 deletions docs/requirements.txt
@@ -1,6 +1,6 @@
numpydoc==1.6.0
sphinx==7.2.6
sphinx_rtd_theme==2.0.0
numpydoc==1.8.0
sphinx==8.1.3
sphinx_rtd_theme==3.0.2

ptlflow
timm==0.9.9
timm==1.0.11