Commit 8032000: code for a2j-transformer
ChanglongJiangGit committed Mar 23, 2023 (parent: 63db824)
Showing 70 changed files with 6,511 additions and 0 deletions.
README.md (new file, +85 lines)
# A2J-Transformer

## Introduction
This is the official implementation for the paper, **"A2J-Transformer: Anchor-to-Joint Transformer Network for 3D Interacting Hand Pose Estimation from a Single RGB Image"**, CVPR 2023.

## About our code


## Installation and Setup

### Requirements

* Our code is tested on Ubuntu 20.04 with NVIDIA 2080Ti and NVIDIA 3090 GPUs; both PyTorch 1.7 and PyTorch 1.11 work.

* Python>=3.7

We recommend using Anaconda to create a conda environment:
```bash
conda create --name a2j_trans python=3.7
```
Then, activate the environment:
```bash
conda activate a2j_trans
```

* PyTorch>=1.7.1, torchvision>=0.8.2 (following instructions [here](https://pytorch.org/))

We recommend installing the following PyTorch and torchvision versions:
```bash
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
```

* Other requirements
```bash
conda install tqdm numpy matplotlib scipy
pip install opencv-python pycocotools
```

### Compiling CUDA operators (following [Deformable-DETR](https://github.com/fundamentalvision/Deformable-DETR))
```bash
cd ./dab_deformable_detr/ops
sh make.sh
```
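
If the build succeeds, the compiled extension should be importable. A quick sanity check, assuming the extension keeps the module name used by Deformable-DETR:
```python
# Run inside the activated a2j_trans environment after make.sh finishes.
import torch
import MultiScaleDeformableAttention  # the extension compiled by make.sh

assert torch.cuda.is_available(), "the deformable attention ops require a CUDA device"
print("Deformable attention ops built successfully")
```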

## Usage

### Dataset preparation

* Please download the [InterHand 2.6M Dataset](https://mks0601.github.io/InterHand2.6M/) and organize it as follows:

```
your_dataset_path/
└── Interhand2.6M_5fps/
├── annotations/
└── images/
```



### Testing on InterHand 2.6M Dataset

* Please download our [pre-trained model](https://drive.google.com/file/d/1QKqokPnSkWMRJjZkj04Nhf0eQCl66-6r/view?usp=share_link) and organize the code as follows:

```
a2j-transformer/
├── dab_deformable_detr/
├── nets/
├── utils/
├── ...py
├── datalist/
| └── ...pkl
└── output/
└── model_dump/
└── snapshot.pth.tar
```
The `datalist` folder and its pkl files contain the dataset lists generated while running the code.
You can also download them [here](https://drive.google.com/file/d/1pfghhGnS5wI23UtF3a4IgBbXz-e2hgYI/view?usp=share_link) and place them manually under the `datalist` folder.

* In `config.py`, set `interhand_anno_dir` and `interhand_images_path` to the absolute paths of the dataset annotations and images.
* In `config.py`, set `cur_dir` to the absolute path of the a2j-transformer code directory (see the sketch at the end of this section for what these edits might look like).
* Run the following script:
```bash
python test.py --gpu <your_gpu_ids>
```
Alternatively, you can change `gpu_ids` directly in `test.py`.
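
For reference, the `config.py` edits above might look like the following; the paths are illustrative placeholders, not values from this repo:
```python
# config.py -- illustrative placeholder paths; substitute your own.
cur_dir = '/home/user/a2j-transformer'                        # code directory
interhand_anno_dir = '/data/Interhand2.6M_5fps/annotations'   # annotation directory
interhand_images_path = '/data/Interhand2.6M_5fps/images'     # image directory
```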
anchor.py (new file, +45 lines)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import cfg
from nets.layer import make_linear_layers


## Generate the 3D coordinates of the 3D anchors:
## 16*16 spatial positions x 3 depth planes = 768 anchors, shape (768, 3).
def generate_all_anchors_3d():
    x_center = 7.5
    y_center = 7.5
    stride = 16
    step_h = 16
    step_w = 16

    d_center = 63.5
    stride_d = 64
    step_d = 3

    anchors_h = np.arange(0, step_h) * stride + x_center
    anchors_w = np.arange(0, step_w) * stride + y_center
    anchors_d = np.arange(0, step_d) * stride_d + d_center
    anchors_x, anchors_y, anchors_z = np.meshgrid(anchors_h, anchors_w, anchors_d)
    all_anchors = np.vstack((anchors_x.ravel(), anchors_y.ravel(), anchors_z.ravel())).transpose()  # (768, 3)
    return all_anchors


class generate_keypoints_coord_new(nn.Module):
    def __init__(self, num_joints, is_3D=True):
        super(generate_keypoints_coord_new, self).__init__()
        self.is_3D = is_3D
        self.num_joints = num_joints

    def forward(self, total_coords, total_weights, total_references):
        lvl_num, batch_size, a, _ = total_coords.shape
        total_coords = total_coords.reshape(lvl_num, batch_size, a, self.num_joints, -1)  ## l,b,a,j,3

        # Softmax over the anchor dimension: each joint aggregates all anchors.
        weights_softmax = F.softmax(total_weights, dim=2)
        weights = torch.unsqueeze(weights_softmax, dim=4).expand(-1, -1, -1, -1, 3)  ## l,b,a,j,3

        # Each anchor proposes a position per joint (42 = 21 joints x 2 hands).
        keypoints = torch.unsqueeze(total_references, dim=3).expand(-1, -1, -1, 42, -1) + total_coords
        pred_keypoints = (keypoints * weights).sum(2)  ## l,b,j,3
        anchors = (torch.unsqueeze(total_references, dim=3) * weights).sum(2)  ## l,b,j,3
        return pred_keypoints, anchors
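
A minimal sketch of how these two pieces fit together; the tensors below are random stand-ins with the shapes from the comments above (the level count of 4 is arbitrary), not real network outputs:
```python
import torch
from anchor import generate_all_anchors_3d, generate_keypoints_coord_new

# 16*16 spatial positions x 3 depth planes = 768 anchors, each (x, y, z).
anchors = generate_all_anchors_3d()
print(anchors.shape)  # (768, 3)

# Random stand-ins: 4 decoder levels, batch of 2, 768 anchors, 42 joints.
lvl, b, a, j = 4, 2, 768, 42
total_coords = torch.randn(lvl, b, a, j * 3)    # per-anchor offsets to each joint
total_weights = torch.randn(lvl, b, a, j)       # per-anchor weight for each joint
total_references = torch.randn(lvl, b, a, 3)    # anchor reference points

head = generate_keypoints_coord_new(num_joints=j)
pred_keypoints, weighted_anchors = head(total_coords, total_weights, total_references)
print(pred_keypoints.shape)  # torch.Size([4, 2, 42, 3])
```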
base.py (new file, +159 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import os.path as osp
import math
import time
import glob
import abc
from torch.utils.data import DataLoader
import torch.optim
import torchvision.transforms as transforms

from config import cfg
from dataset import Dataset
from timer import Timer
from logger import colorlogger
from torch.nn.parallel.data_parallel import DataParallel
from model import get_model


class Base(object):
    __metaclass__ = abc.ABCMeta

    def __init__(self, log_name='logs.txt'):
        self.cur_epoch = 0

        # timers
        self.tot_timer = Timer()
        self.gpu_timer = Timer()
        self.read_timer = Timer()

        # logger
        self.logger = colorlogger(cfg.log_dir, log_name=log_name)

    @abc.abstractmethod
    def _make_batch_generator(self):
        return

    @abc.abstractmethod
    def _make_model(self):
        return


class Trainer(Base):

    def __init__(self):
        super(Trainer, self).__init__(log_name='train_logs.txt')

    def get_optimizer(self, model):
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        return optimizer

    def set_lr(self, epoch):
        if len(cfg.lr_dec_epoch) == 0:
            return cfg.lr

        # Find the first decay milestone the current epoch has not reached yet.
        for e in cfg.lr_dec_epoch:
            if epoch < e:
                break
        if epoch < cfg.lr_dec_epoch[-1]:
            idx = cfg.lr_dec_epoch.index(e)
            for g in self.optimizer.param_groups:
                g['lr'] = cfg.lr / (cfg.lr_dec_factor ** idx)
        else:
            for g in self.optimizer.param_groups:
                g['lr'] = cfg.lr / (cfg.lr_dec_factor ** len(cfg.lr_dec_epoch))

    def get_lr(self):
        for g in self.optimizer.param_groups:
            cur_lr = g['lr']
        return cur_lr

    def _make_batch_generator(self):
        # data load and construct batch generator
        self.logger.info("Creating train dataset...")
        trainset_loader = Dataset(transforms.ToTensor(), "train")
        batch_generator = DataLoader(dataset=trainset_loader, batch_size=cfg.num_gpus*cfg.train_batch_size,
                                     shuffle=True, num_workers=cfg.num_thread, pin_memory=True, drop_last=True)

        self.joint_num = trainset_loader.joint_num
        self.itr_per_epoch = math.ceil(trainset_loader.__len__() / cfg.num_gpus / cfg.train_batch_size)
        self.batch_generator = batch_generator

    def _make_model(self):
        # prepare network
        self.logger.info("Creating graph and optimizer...")
        model = get_model('train', self.joint_num)
        model = DataParallel(model).cuda()
        optimizer = self.get_optimizer(model)
        if cfg.continue_train:
            start_epoch, model, optimizer = self.load_model(model, optimizer)
        else:
            start_epoch = 0
        model.train()

        self.start_epoch = start_epoch
        self.model = model
        self.optimizer = optimizer

    def save_model(self, state, epoch):
        file_path = osp.join(cfg.model_dir, 'snapshot_{}.pth.tar'.format(str(epoch)))
        torch.save(state, file_path)
        self.logger.info("Write snapshot into {}".format(file_path))

    def load_model(self, model, optimizer):
        model_file_list = glob.glob(osp.join(cfg.model_dir, '*.pth.tar'))
        cur_epoch = max([int(file_name[file_name.find('snapshot_') + 9: file_name.find('.pth.tar')]) for file_name in model_file_list])
        model_path = osp.join(cfg.model_dir, 'snapshot_' + str(cur_epoch) + '.pth.tar')
        self.logger.info('Load checkpoint from {}'.format(model_path))
        ckpt = torch.load(model_path)
        start_epoch = ckpt['epoch'] + 1

        model.load_state_dict(ckpt['network'])
        try:
            optimizer.load_state_dict(ckpt['optimizer'])
        except Exception:
            # Older snapshots may lack a saved optimizer state; keep the fresh one.
            pass

        return start_epoch, model, optimizer


class Tester(Base):

    def __init__(self):
        super(Tester, self).__init__(log_name='test_logs.txt')

    def _make_batch_generator(self, test_set):
        # data load and construct batch generator
        self.logger.info("Creating " + test_set + " dataset...")
        testset_loader = Dataset(transforms.ToTensor(), test_set)
        batch_generator = DataLoader(dataset=testset_loader, batch_size=cfg.num_gpus*cfg.test_batch_size,
                                     shuffle=False, num_workers=cfg.num_thread, pin_memory=True)

        self.joint_num = testset_loader.joint_num
        self.batch_generator = batch_generator
        self.testset = testset_loader

    def _make_model(self):
        model_path = os.path.join(cfg.model_dir, 'snapshot.pth.tar')
        assert os.path.exists(model_path), 'Cannot find model at ' + model_path
        self.logger.info('Load checkpoint from {}'.format(model_path))

        # prepare network
        self.logger.info("Creating graph...")
        model = get_model('test', self.joint_num)
        model = DataParallel(model).cuda()
        ckpt = torch.load(model_path)
        model.load_state_dict(ckpt['network'])
        model.eval()

        self.model = model

    def _evaluate(self, preds):
        mpjpe_dict = self.testset.evaluate(preds)
        return mpjpe_dict