Add HO3D-v2 training and evaluation
hassony2 committed Sep 28, 2020
1 parent 056697a commit 00c5f77
Showing 8 changed files with 272 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -19,3 +19,7 @@ tmp*
*.sw*
*.png
*.pdf

jsonres/
pred.json
pred.zip
29 changes: 26 additions & 3 deletions README.md
@@ -65,7 +65,7 @@ handobjectconsist/
`tar -xvf assets/fhbhands_fits.tgz -C assets/`
-- Download [pre-trained models](https://github.com/hassony2/handobjectconsist/releases/download/v0.1/releasemodels.zip)
+- Download [pre-trained models](https://github.com/hassony2/handobjectconsist/releases/download/v0.3/releasemodels.zip)
`wget https://github.com/hassony2/handobjectconsist/releases/download/v0.2/releasemodels.zip`
@@ -94,10 +94,33 @@ releasemodels/
### HO3D
*Optional*: Download the [HO3D-v2](https://files.icg.tugraz.at/d/76661ed06445490ab21c/) dataset.
-Note that all results in our paper are reported on a **subset** of the current dataset which was published as an [early release](https://arxiv.org/abs/1907.01481v1).
+#### CVPR 2020
+Note that all results in our paper are reported on a **subset** of the current dataset, which was published as an [early release](https://arxiv.org/abs/1907.01481v1); additionally, we used synthetic data which is not released.
+The results are therefore *not directly comparable* with the [final published results](https://arxiv.org/abs/1907.01481), which are reported on the v2 version of the dataset.
#### Codalab challenge pre-trained model
After submission, I retrained a baseline model on the current dataset (the official release of HO3D, which I refer to as HO3D-v2). You can get the model from the `releasemodels` archive.
Evaluate the pre-trained model:
- Download [pre-trained models](https://github.com/hassony2/handobjectconsist/releases/download/v0.3/releasemodels.zip)
- Extract the pre-trained models: `unzip releasemodels.zip`
- Run the evaluation code and generate the codalab submission file:
`python evalho3dv2.py --resume releasemodels/ho3dv2/realonly/checkpoint_200.pth --val_split test --json_folder jsonres/res`
This will create a file `pred.zip` ready for upload to the [codalab challenge](https://competitions.codalab.org/competitions/22485).
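To sanity-check the archive before uploading, you can list its contents; a minimal sketch (the expected single `pred.json` entry follows from the packaging code in `meshreg/datasets/ho3dv2utils.py`):

```python
import zipfile

# Inspect the submission archive produced by evalho3dv2.py
with zipfile.ZipFile("pred.zip") as zf:
    print(zf.namelist())  # should print ['pred.json']
```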
#### Training model on HO3D-v2
- Download the [HO3D-v2](https://files.icg.tugraz.at/d/76661ed06445490ab21c/) dataset.
- Launch training using `python trainmeshreg.py`, providing all arguments as in `releasemodels/ho3dv2/realonly/opt.txt` (see the sketch below).
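To reuse the released configuration verbatim, one option is to rebuild the command line from the saved options file. A minimal sketch, assuming `opt.txt` stores one `key: value` pair per line (this format is an assumption; adjust the parsing to whatever `save_args` actually wrote):

```python
# Hypothetical helper: turn releasemodels/ho3dv2/realonly/opt.txt into CLI flags
import shlex

flags = []
with open("releasemodels/ho3dv2/realonly/opt.txt") as f:
    for line in f:
        key, sep, val = line.partition(":")
        if not sep:
            continue
        key, val = key.strip(), val.strip()
        if val == "True":
            flags.append(f"--{key}")  # store_true style flags
        elif val not in ("False", "None", ""):
            flags.extend([f"--{key}", val])

print("python trainmeshreg.py " + " ".join(shlex.quote(flag) for flag in flags))
```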
# Demo
Run the demo on the FPHAB dataset.
2 changes: 2 additions & 0 deletions environment.yml
@@ -1,5 +1,6 @@
name: handobject_env
channels:
  - open3d-admin
  - menpo
  - pytorch
  - defaults
@@ -9,6 +10,7 @@ dependencies:
  - cupy
  - joblib
  - opencv
  - open3d=0.9
  - pandas
  - python=3.7
  - pyyaml
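After rebuilding the environment, a quick check that the pinned Open3D version resolved (a sketch; `open3d` exposes `__version__`):

```python
import open3d

# environment.yml pins open3d=0.9; a different version suggests a stale env
print(open3d.__version__)
```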
156 changes: 156 additions & 0 deletions evalho3dv2.py
@@ -0,0 +1,156 @@
import argparse
from datetime import datetime
import os
import random

from matplotlib import pyplot as plt
import numpy as np
import torch

from libyana.exputils.argutils import save_args
from libyana.modelutils import freeze

from meshreg.datasets import collate
from meshreg.netscripts import evalpass, reloadmodel, get_dataset


plt.switch_backend("agg")


def main(args):
    torch.cuda.manual_seed_all(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)
    # Initialize hosting
    dat_str = args.val_dataset
    now = datetime.now()
    exp_id = (
        f"checkpoints/{dat_str}_mini{args.mini_factor}/"
        f"{now.year}_{now.month:02d}_{now.day:02d}/"
        f"{args.com}_frac{args.fraction}_mode{args.mode}_bs{args.batch_size}_"
        f"objs{args.obj_scale_factor}_objt{args.obj_trans_factor}"
    )

    # Initialize local checkpoint folder
    save_args(args, exp_id, "opt")
    result_folder = os.path.join(exp_id, "results")
    os.makedirs(result_folder, exist_ok=True)
    pyapt_path = os.path.join(result_folder, f"{args.pyapt_id}__{now.strftime('%H_%M_%S')}")
    with open(pyapt_path, "a") as t_f:
        t_f.write(" ")

    val_dataset, input_size = get_dataset.get_dataset(
        args.val_dataset,
        split=args.val_split,
        meta={"version": args.version, "split_mode": "paper"},
        use_cache=args.use_cache,
        mini_factor=args.mini_factor,
        mode=args.mode,
        fraction=args.fraction,
        no_augm=True,
        center_idx=args.center_idx,
        scale_jittering=0,
        center_jittering=0,
        sample_nb=None,
        has_dist2strong=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False,
        collate_fn=collate.meshreg_collate,
    )

    opts = reloadmodel.load_opts(args.resume)
    model, epoch = reloadmodel.reload_model(args.resume, opts)
    if args.render_results:
        render_folder = os.path.join(exp_id, "renders", f"epoch{epoch:04d}")
        os.makedirs(render_folder, exist_ok=True)
        print(f"Rendering to {render_folder}")
    else:
        render_folder = None
    img_folder = os.path.join(exp_id, "images", f"epoch{epoch:04d}")
    os.makedirs(img_folder, exist_ok=True)
    freeze.freeze_batchnorm_stats(model)  # Freeze batchnorm

    fig = plt.figure(figsize=(12, 4))
    save_results = {}
    save_results["opt"] = dict(vars(args))
    save_results["val_losses"] = []
    os.makedirs(args.json_folder, exist_ok=True)
    json_path = os.path.join(args.json_folder, f"{args.val_split}.json")
    evalpass.epoch_pass(
        val_loader,
        model,
        optimizer=None,
        scheduler=None,
        epoch=epoch,
        img_folder=img_folder,
        fig=fig,
        display_freq=args.display_freq,
        dump_results_path=json_path,
        render_folder=render_folder,
        render_freq=args.render_freq,
        true_root=args.true_root,
    )
    print(f"Saved results for split {args.val_split} to {json_path}")


if __name__ == "__main__":
    torch.multiprocessing.set_sharing_strategy("file_system")
    # torch.multiprocessing.set_start_method("forkserver")
    parser = argparse.ArgumentParser()
    parser.add_argument("--com", default="debug/")

    # Dataset params
    parser.add_argument("--val_dataset", choices=["ho3dv2"], default="ho3dv2")
    parser.add_argument("--val_split", default="val")
    parser.add_argument("--mini_factor", type=float, default=1)
    parser.add_argument("--max_verts", type=int, default=1000)
    parser.add_argument("--use_cache", action="store_true")
    parser.add_argument("--synth", action="store_true")
    parser.add_argument("--version", default=3, type=int)
    parser.add_argument("--fraction", type=float, default=1)
    parser.add_argument("--mode", choices=["strong", "weak", "full"], default="strong")

    # Test options
    parser.add_argument("--dump_results", action="store_true")
    parser.add_argument("--render_results", action="store_true")
    parser.add_argument("--render_freq", type=int, default=10)

    # Model params
    parser.add_argument("--center_idx", default=9, type=int)
    parser.add_argument(
        "--true_root", action="store_true", help="Replace predicted wrist position with ground truth root"
    )
    parser.add_argument("--resume")

    # Training params
    parser.add_argument("--manual_seed", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--freeze_batchnorm", action="store_true")
    parser.add_argument("--pyapt_id")
    parser.add_argument("--criterion2d", choices=["l2", "l1", "smooth_l1"], default="l2")

    # Weighting
    parser.add_argument("--obj_trans_factor", type=float, default=1)
    parser.add_argument("--obj_scale_factor", type=float, default=1)

    # Evaluation params
    parser.add_argument("--mask_threshold", type=float, default=0.9)
    parser.add_argument("--json_folder", default="jsonres/res")

    # Logging params
    parser.add_argument("--display_freq", type=int, default=100)
    parser.add_argument("--snapshot", type=int, default=50)

    args = parser.parse_args()
    for key, val in sorted(vars(args).items(), key=lambda x: x[0]):
        print(f"{key}: {val}")

    main(args)
7 changes: 6 additions & 1 deletion meshreg/datasets/ho3dv2utils.py
@@ -1,4 +1,6 @@
from collections import defaultdict
import subprocess
import shutil
import json
import os
import pickle
@@ -123,7 +125,7 @@ def get_objectsplit_infos(seqs, root, subfolder="train", fraction=1):
    return all_sequences, seq_map, closeseq_map, strongs, weaks, idxs


-def dump(pred_out_path, xyz_pred_list, verts_pred_list):
+def dump(pred_out_path, xyz_pred_list, verts_pred_list, codalab=True):
    """ Save predictions into a json file for official ho3dv2 evaluation. """
    # make sure it's only lists
    def roundall(rows):
@@ -139,3 +141,6 @@ def roundall(rows):
"Dumped %d joints and %d verts predictions to %s"
% (len(xyz_pred_list), len(verts_pred_list), pred_out_path)
)
if codalab:
shutil.copy(pred_out_path, "pred.json")
subprocess.call(["zip", "-j", "pred.zip", "pred.json"])
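A note on the packaging step above: it shells out to the external `zip` binary. Where that binary is unavailable, the standard library can produce the same flat archive; a minimal sketch with a hypothetical helper name:

```python
import zipfile

def pack_codalab(pred_out_path):
    # Same effect as `zip -j pred.zip pred.json`: a flat archive with one entry
    with zipfile.ZipFile("pred.zip", "w", zipfile.ZIP_DEFLATED) as zf:
        zf.write(pred_out_path, arcname="pred.json")
```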
1 change: 0 additions & 1 deletion meshreg/netscripts/epochpass.py
@@ -9,7 +9,6 @@

from meshreg.datasets.queries import BaseQueries
from meshreg.netscripts import evaluate
-from meshreg.neurender import fastrender
from meshreg.visualize import samplevis, evalvis
from meshreg.visualize import consistdisplay

77 changes: 77 additions & 0 deletions meshreg/netscripts/evalpass.py
@@ -0,0 +1,77 @@
import os

import numpy as np
from tqdm import tqdm
import torch

from libyana.evalutils.avgmeter import AverageMeters
from libyana.evalutils.zimeval import EvalUtil

from meshreg.visualize import samplevis
from meshreg.netscripts import evaluate
from meshreg.datasets.queries import BaseQueries
from meshreg.datasets import ho3dv2utils


def get_order_idxs():
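    # Assumption (undocumented in the original): this permutation maps the
    # network's 21-joint ordering to the convention expected by the HO3D
    # evaluation, and np.argsort gives the inverse permutation, so that
    # joints[reorder_idxs][unorder_idxs] recovers the original ordering.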
    reorder_idxs = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
    unorder_idxs = np.argsort(reorder_idxs)
    return reorder_idxs, unorder_idxs


def epoch_pass(
    loader,
    model,
    optimizer=None,
    scheduler=None,
    epoch=0,
    img_folder=None,
    fig=None,
    display_freq=10,
    epoch_display_freq=1,
    lr_decay_gamma=0,
    freeze_batchnorm=True,
    dump_results_path=None,
    render_folder=None,
    render_freq=10,
    true_root=False,
):
    prefix = "val"
    reorder_idxs, unorder_idxs = get_order_idxs()
    evaluators = {
        # "joints2d_trans": EvalUtil(),
        "joints2d_base": EvalUtil(),
        "corners2d_base": EvalUtil(),
        "verts2d_base": EvalUtil(),
        "joints3d_cent": EvalUtil(),
        "joints3d": EvalUtil(),
    }
    model.eval()
    model.cuda()
    avg_meters = AverageMeters()
    all_joints = []
    all_verts = []
    for batch_idx, batch in enumerate(tqdm(loader)):
        with torch.no_grad():
            loss, results, losses = model(batch)
        # Collect hand joints
        if true_root:
            results["recov_joints3d"][:, 0] = batch[BaseQueries.JOINTS3D][:, 0]
        recov_joints = results["recov_joints3d"].cpu().detach()[:, unorder_idxs]
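        # Interpretation (not stated in the original code): flipping x here and
        # negating the whole array below is equivalent to multiplying by
        # diag(1, -1, -1), i.e. converting to the OpenGL-style camera frame
        # used by the HO3D annotations.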
        recov_joints[:, :, 0] = -recov_joints[:, :, 0]
        new_joints = [-val.numpy()[0] for val in recov_joints.split(1)]
        all_joints.extend(new_joints)

        # Collect hand vertices
        recov_verts = results["recov_handverts3d"].cpu().detach()
        recov_verts[:, :, 0] = -recov_verts[:, :, 0]
        new_verts = [-val.numpy()[0] for val in recov_verts.split(1)]
        all_verts.extend(new_verts)

        evaluate.feed_avg_meters(avg_meters, batch, results)
        if batch_idx % display_freq == 0 and epoch % epoch_display_freq == 0:
            img_filepath = f"{prefix}_epoch{epoch:04d}_batch{batch_idx:06d}.png"
            save_img_path = os.path.join(img_folder, img_filepath)
            samplevis.sample_vis(batch, results, fig=fig, save_img_path=save_img_path)
        evaluate.feed_evaluators(evaluators, batch, results)
    ho3dv2utils.dump(dump_results_path, all_joints, all_verts, codalab=True)
2 changes: 1 addition & 1 deletion trainmeshreg.py
@@ -253,7 +253,7 @@ def main(args):
parser.add_argument("--no_augm", action="store_true", help="Prevent all data augmenation")
parser.add_argument("--block_rot", action="store_true", help="Prevent rotation during data augmentation")
parser.add_argument("--max_rot", default=0, type=float, help="Max rotation for data augmentation")
parser.add_argument("--version", default=1, type=int, help="Version of HO3D dataset to use")
parser.add_argument("--version", default=3, type=int, help="Version of synthetic HO3D dataset to use")
parser.add_argument("--center_idx", default=9, type=int)
parser.add_argument(
"--center_jittering", type=float, default=0.1, help="Controls magnitude of center jittering"
