# A2J-Transformer

## Introduction

This is the official implementation of the paper **"A2J-Transformer: Anchor-to-Joint Transformer Network for 3D Interacting Hand Pose Estimation from a Single RGB Image"**, CVPR 2023.

## About our code

## Installation and Setup

### Requirements

* Our code is tested on Ubuntu 20.04 with an NVIDIA 2080 Ti GPU and an NVIDIA 3090 GPU; both PyTorch 1.7 and PyTorch 1.11 work.

* Python >= 3.7

We recommend using Anaconda to create a conda environment:
```bash
conda create --name a2j_trans python=3.7
```
Then, activate the environment:
```bash
conda activate a2j_trans
```

* PyTorch >= 1.7.1, torchvision >= 0.8.2 (following the instructions [here](https://pytorch.org/))

We recommend the following PyTorch and torchvision installation:
```bash
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
```

* Other requirements
```bash
conda install tqdm numpy matplotlib scipy
pip install opencv-python pycocotools
```

### Compiling CUDA operators (following [Deformable-DETR](https://github.com/fundamentalvision/Deformable-DETR))
```bash
cd ./dab_deformable_detr/ops
sh make.sh
```

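If your build follows Deformable-DETR's layout, its unit test can be used to verify the compiled operators (this assumes the ops folder ships a `test.py`, as Deformable-DETR's does):
```bash
cd ./dab_deformable_detr/ops
python test.py  # Deformable-DETR's test reports "True" for all gradient checks
```
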
## Usage

### Dataset preparation

* Please download the [InterHand 2.6M Dataset](https://mks0601.github.io/InterHand2.6M/) and organize it as follows:

```
your_dataset_path/
└── Interhand2.6M_5fps/
    ├── annotations/
    └── images/
```

### Testing on InterHand 2.6M Dataset

* Please download our [pre-trained model](https://drive.google.com/file/d/1QKqokPnSkWMRJjZkj04Nhf0eQCl66-6r/view?usp=share_link) and organize the code as follows:

```
a2j-transformer/
├── dab_deformable_detr/
├── nets/
├── utils/
├── ...py
├── datalist/
│   └── ...pkl
└── output/
    └── model_dump/
        └── snapshot.pth.tar
```
The `datalist` folder and its pkl files hold the dataset lists generated while running the code.
You can also download them [here](https://drive.google.com/file/d/1pfghhGnS5wI23UtF3a4IgBbXz-e2hgYI/view?usp=share_link) and place them manually in the `datalist` folder.

* In `config.py`, set `interhand_anno_dir` and `interhand_images_path` to the absolute paths of the dataset annotations and images (see the sketch below).
* In `config.py`, set `cur_dir` to the a2j-transformer code directory.
* Run the following script:
```bash
python test.py --gpu <your_gpu_ids>
```
You can also change `gpu_ids` directly in `test.py`.
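
For reference, a minimal sketch of the `config.py` edits described above (the paths are placeholders; the actual variables are defined in the repository's `config.py`):
```python
# config.py -- illustrative placeholder paths, adjust to your setup
interhand_anno_dir = '/data/your_dataset_path/Interhand2.6M_5fps/annotations'
interhand_images_path = '/data/your_dataset_path/Interhand2.6M_5fps/images'
cur_dir = '/home/user/a2j-transformer'
```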
---

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import cfg
from nets.layer import make_linear_layers


## Generate the 3D coordinates of the 3D anchors: a 16 x 16 spatial grid at
## 3 depth planes, i.e. 256 * 3 = 768 anchors in total.
def generate_all_anchors_3d():
    x_center = 7.5
    y_center = 7.5
    stride = 16
    step_h = 16
    step_w = 16

    d_center = 63.5
    stride_d = 64
    step_d = 3

    anchors_h = np.arange(0, step_h) * stride + x_center
    anchors_w = np.arange(0, step_w) * stride + y_center
    anchors_d = np.arange(0, step_d) * stride_d + d_center
    anchors_x, anchors_y, anchors_z = np.meshgrid(anchors_h, anchors_w, anchors_d)
    all_anchors = np.vstack((anchors_x.ravel(), anchors_y.ravel(), anchors_z.ravel())).transpose()  # (768, 3)
    return all_anchors
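
# For orientation (inferred from the constants above, not stated elsewhere):
# anchors_h and anchors_w run 7.5, 23.5, ..., 247.5 (16 values at stride 16,
# covering a 256-pixel extent), while anchors_d places depth planes at
# 63.5, 127.5 and 191.5 (3 values at stride 64).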


class generate_keypoints_coord_new(nn.Module):
    def __init__(self, num_joints, is_3D=True):
        super(generate_keypoints_coord_new, self).__init__()
        self.is_3D = is_3D
        self.num_joints = num_joints

    def forward(self, total_coords, total_weights, total_references):
        lvl_num, batch_size, a, _ = total_coords.shape
        # Per-anchor offsets for every joint: (levels, batch, anchors, joints, 3)
        total_coords = total_coords.reshape(lvl_num, batch_size, a, self.num_joints, -1)

        # Softmax over the anchor dimension, expanded to weight x, y and z alike.
        weights_softmax = F.softmax(total_weights, dim=2)
        weights = torch.unsqueeze(weights_softmax, dim=4).expand(-1, -1, -1, -1, 3)  ## l,b,a,j,3

        # Each anchor votes for all 42 joints (21 per hand): anchor position plus
        # predicted offset, aggregated as a weighted sum over anchors.
        keypoints = torch.unsqueeze(total_references, dim=3).expand(-1, -1, -1, 42, -1) + total_coords
        pred_keypoints = (keypoints * weights).sum(2)  ## l,b,j,3
        anchors = (torch.unsqueeze(total_references, dim=3) * weights).sum(2)
        return pred_keypoints, anchors
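

if __name__ == '__main__':
    # Illustrative sanity check, not part of the original file: confirm the
    # anchor count and the aggregation output shapes with dummy predictions.
    anchors_np = generate_all_anchors_3d()
    assert anchors_np.shape == (16 * 16 * 3, 3)  # 768 anchors, (x, y, z) each

    lvl, bs, num_anchors, num_joints = 2, 1, 768, 42
    head = generate_keypoints_coord_new(num_joints)
    coords = torch.zeros(lvl, bs, num_anchors, num_joints * 3)
    weights = torch.zeros(lvl, bs, num_anchors, num_joints)
    refs = torch.as_tensor(anchors_np, dtype=torch.float32).expand(lvl, bs, -1, -1)
    pred, agg = head(coords, weights, refs)
    assert pred.shape == (lvl, bs, num_joints, 3)
    assert agg.shape == (lvl, bs, num_joints, 3)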
```

---

```python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import os.path as osp
import math
import time
import glob
import abc
from torch.utils.data import DataLoader
import torch.optim
import torchvision.transforms as transforms

from config import cfg
from dataset import Dataset
from timer import Timer
from logger import colorlogger
from torch.nn.parallel.data_parallel import DataParallel
from model import get_model

class Base(metaclass=abc.ABCMeta):

    def __init__(self, log_name='logs.txt'):
        self.cur_epoch = 0

        # timers
        self.tot_timer = Timer()
        self.gpu_timer = Timer()
        self.read_timer = Timer()

        # logger
        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

    @abc.abstractmethod
    def _make_batch_generator(self):
        return

    @abc.abstractmethod
    def _make_model(self):
        return


class Trainer(Base):

    def __init__(self):
        super(Trainer, self).__init__(log_name='train_logs.txt')

    def get_optimizer(self, model):
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        return optimizer

    def set_lr(self, epoch):
        if len(cfg.lr_dec_epoch) == 0:
            return cfg.lr

        # Step decay: find the first milestone the current epoch has not yet
        # reached, and scale the base lr down by lr_dec_factor per passed milestone.
        for e in cfg.lr_dec_epoch:
            if epoch < e:
                break
        if epoch < cfg.lr_dec_epoch[-1]:
            idx = cfg.lr_dec_epoch.index(e)
            for g in self.optimizer.param_groups:
                g['lr'] = cfg.lr / (cfg.lr_dec_factor ** idx)
        else:
            for g in self.optimizer.param_groups:
                g['lr'] = cfg.lr / (cfg.lr_dec_factor ** len(cfg.lr_dec_epoch))
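
    # Worked example (hypothetical cfg values, not from the repository): with
    # cfg.lr = 1e-4, cfg.lr_dec_epoch = [24, 35] and cfg.lr_dec_factor = 10,
    # epochs 0-23 train at 1e-4, epochs 24-34 at 1e-5, and epochs 35+ at 1e-6.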

    def get_lr(self):
        # All parameter groups share one lr; return the last group's value.
        for g in self.optimizer.param_groups:
            cur_lr = g['lr']
        return cur_lr

    def _make_batch_generator(self):
        # data load and construct batch generator
        self.logger.info("Creating train dataset...")
        trainset_loader = Dataset(transforms.ToTensor(), "train")
        batch_generator = DataLoader(dataset=trainset_loader, batch_size=cfg.num_gpus * cfg.train_batch_size,
                                     shuffle=True, num_workers=cfg.num_thread, pin_memory=True, drop_last=True)

        self.joint_num = trainset_loader.joint_num
        self.itr_per_epoch = math.ceil(len(trainset_loader) / cfg.num_gpus / cfg.train_batch_size)
        self.batch_generator = batch_generator

    def _make_model(self):
        # prepare network
        self.logger.info("Creating graph and optimizer...")
        model = get_model('train', self.joint_num)
        model = DataParallel(model).cuda()
        optimizer = self.get_optimizer(model)
        if cfg.continue_train:
            start_epoch, model, optimizer = self.load_model(model, optimizer)
        else:
            start_epoch = 0
        model.train()

        self.start_epoch = start_epoch
        self.model = model
        self.optimizer = optimizer

    def save_model(self, state, epoch):
        file_path = osp.join(cfg.model_dir, 'snapshot_{}.pth.tar'.format(str(epoch)))
        torch.save(state, file_path)
        self.logger.info("Write snapshot into {}".format(file_path))

    def load_model(self, model, optimizer):
        model_file_list = glob.glob(osp.join(cfg.model_dir, '*.pth.tar'))
        # Resume from the snapshot with the largest epoch index.
        cur_epoch = max([int(file_name[file_name.find('snapshot_') + 9:file_name.find('.pth.tar')]) for file_name in model_file_list])
        model_path = osp.join(cfg.model_dir, 'snapshot_' + str(cur_epoch) + '.pth.tar')
        self.logger.info('Load checkpoint from {}'.format(model_path))
        ckpt = torch.load(model_path)
        start_epoch = ckpt['epoch'] + 1

        model.load_state_dict(ckpt['network'])
        try:
            optimizer.load_state_dict(ckpt['optimizer'])
        except Exception:
            # The snapshot may not contain a compatible optimizer state.
            pass

        return start_epoch, model, optimizer


class Tester(Base):

    def __init__(self):
        super(Tester, self).__init__(log_name='test_logs.txt')

    def _make_batch_generator(self, test_set):
        # data load and construct batch generator
        self.logger.info("Creating " + test_set + " dataset...")
        testset_loader = Dataset(transforms.ToTensor(), test_set)
        batch_generator = DataLoader(dataset=testset_loader, batch_size=cfg.num_gpus * cfg.test_batch_size,
                                     shuffle=False, num_workers=cfg.num_thread, pin_memory=True)

        self.joint_num = testset_loader.joint_num
        self.batch_generator = batch_generator
        self.testset = testset_loader

    def _make_model(self):
        model_path = os.path.join(cfg.model_dir, 'snapshot.pth.tar')
        assert os.path.exists(model_path), 'Cannot find model at ' + model_path
        self.logger.info('Load checkpoint from {}'.format(model_path))

        # prepare network
        self.logger.info("Creating graph...")
        model = get_model('test', self.joint_num)
        model = DataParallel(model).cuda()
        ckpt = torch.load(model_path)
        model.load_state_dict(ckpt['network'])
        model.eval()

        self.model = model

    def _evaluate(self, preds):
        mpjpe_dict = self.testset.evaluate(preds)
        return mpjpe_dict
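
# Typical usage (a sketch of the intended call order; the repository's own
# train/test scripts are the authoritative wiring):
#   trainer = Trainer()
#   trainer._make_batch_generator()
#   trainer._make_model()
#   # ... iterate over trainer.batch_generator, call trainer.set_lr(epoch),
#   #     and save checkpoints with trainer.save_model(...) ...
#
#   tester = Tester()
#   tester._make_batch_generator('test')
#   tester._make_model()
#   # ... collect predictions, then tester._evaluate(preds) ...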
```