diff --git a/README.md b/README.md
new file mode 100644
index 0000000..45d3d6c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,85 @@
+# A2J-Transformer
+
+## Introduction
+This is the official implementation of the paper **"A2J-Transformer: Anchor-to-Joint Transformer Network for 3D Interacting Hand Pose Estimation from a Single RGB Image"**, CVPR 2023.
+
+# About our code
+
+
+## Installation and Setup
+
+### Requirements
+
+* Our code is tested on Ubuntu 20.04 with NVIDIA 2080Ti and NVIDIA 3090 GPUs; both PyTorch 1.7 and PyTorch 1.11 work.
+
+* Python>=3.7
+
+  We recommend using Anaconda to create a conda environment:
+  ```bash
+  conda create --name a2j_trans python=3.7
+  ```
+  Then, activate the environment:
+  ```bash
+  conda activate a2j_trans
+  ```
+
+* PyTorch>=1.7.1, torchvision>=0.8.2 (following the instructions [here](https://pytorch.org/))
+
+  We recommend the following PyTorch and torchvision versions:
+  ```bash
+  conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
+  ```
+
+* Other requirements
+  ```bash
+  conda install tqdm numpy matplotlib scipy
+  pip install opencv-python pycocotools
+  ```
+
+### Compiling CUDA operators (following [Deformable-DETR](https://github.com/fundamentalvision/Deformable-DETR))
+```bash
+cd ./dab_deformable_detr/ops
+sh make.sh
+```
+
+## Usage
+
+### Dataset preparation
+
+* Please download the [InterHand 2.6M dataset](https://mks0601.github.io/InterHand2.6M/) and organize it as follows:
+
+  ```
+  your_dataset_path/
+  └── Interhand2.6M_5fps/
+      ├── annotations/
+      └── images/
+  ```
+
+
+
+### Testing on InterHand 2.6M Dataset
+
+* Please download our [pre-trained model](https://drive.google.com/file/d/1QKqokPnSkWMRJjZkj04Nhf0eQCl66-6r/view?usp=share_link) and organize the code as follows:
+
+  ```
+  a2j-transformer/
+  ├── dab_deformable_detr/
+  ├── nets/
+  ├── utils/
+  ├── ...py
+  ├── datalist/
+  |   └── ...pkl
+  └── output/
+      └── model_dump/
+          └── snapshot.pth.tar
+  ```
+  The `datalist` folder and its pkl files hold the data lists generated while running the code.
+  You can also download them [here](https://drive.google.com/file/d/1pfghhGnS5wI23UtF3a4IgBbXz-e2hgYI/view?usp=share_link) and manually place them under the `datalist` folder.
+
+* In `config.py`, set `interhand_anno_dir` and `interhand_images_path` to the absolute dataset directories.
+* In `config.py`, set `cur_dir` to the a2j-transformer code directory.
+* Run the following script:
+  ```bash
+  python test.py --gpu <your_gpu_ids>
+  ```
+  You can also change `gpu_ids` in `test.py`.
\ No newline at end of file
diff --git a/anchor.py b/anchor.py
new file mode 100644
index 0000000..a82485c
--- /dev/null
+++ b/anchor.py
@@ -0,0 +1,45 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from config import cfg
+from nets.layer import make_linear_layers
+
+
+## generate 3D coords of 3D anchors. 
num is 256*3 +def generate_all_anchors_3d(): + x_center = 7.5 + y_center = 7.5 + stride = 16 + step_h = 16 + step_w = 16 + + d_center = 63.5 + stride_d = 64 + step_d = 3 + + anchors_h = np.arange(0,step_h) * stride + x_center + anchors_w = np.arange(0,step_w) * stride + y_center + anchors_d = np.arange(0,step_d) * stride_d + d_center + anchors_x, anchors_y, anchors_z = np.meshgrid(anchors_h, anchors_w, anchors_d) + all_anchors = np.vstack((anchors_x.ravel(), anchors_y.ravel(), anchors_z.ravel())).transpose() #256*3 + return all_anchors + + +class generate_keypoints_coord_new(nn.Module): + def __init__(self, num_joints, is_3D=True): + super(generate_keypoints_coord_new, self).__init__() + self.is_3D = is_3D + self.num_joints = num_joints + + def forward(self, total_coords, total_weights, total_references): + lvl_num, batch_size, a, _ = total_coords.shape + total_coords = total_coords.reshape(lvl_num, batch_size, a, self.num_joints, -1) ## l,b,a,j,3 + + weights_softmax = F.softmax(total_weights, dim=2) + weights = torch.unsqueeze(weights_softmax, dim=4).expand(-1,-1,-1,-1, 3) ## l,b,a,j,3 + + keypoints = torch.unsqueeze(total_references, dim = 3).expand(-1,-1,-1,42,-1) + total_coords + pred_keypoints = (keypoints * weights).sum(2) ## l,b,a,3 + anchors = (torch.unsqueeze(total_references, dim = 3) * weights).sum(2) + return pred_keypoints, anchors \ No newline at end of file diff --git a/base.py b/base.py new file mode 100644 index 0000000..32fa0b8 --- /dev/null +++ b/base.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import os.path as osp +import math +import time +import glob +import abc +from torch.utils.data import DataLoader +import torch.optim +import torchvision.transforms as transforms + +from config import cfg +from dataset import Dataset +from timer import Timer +from logger import colorlogger +from torch.nn.parallel.data_parallel import DataParallel +from model import get_model + +class Base(object): + __metaclass__ = abc.ABCMeta + + def __init__(self, log_name='logs.txt'): + + self.cur_epoch = 0 + + # timer + self.tot_timer = Timer() + self.gpu_timer = Timer() + self.read_timer = Timer() + + # logger + self.logger = colorlogger(cfg.log_dir, log_name=log_name) + + @abc.abstractmethod + def _make_batch_generator(self): + return + + @abc.abstractmethod + def _make_model(self): + return + + +class Trainer(Base): + + def __init__(self): + super(Trainer, self).__init__(log_name = 'train_logs.txt') + + def get_optimizer(self, model): + optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) + return optimizer + + def set_lr(self, epoch): + if len(cfg.lr_dec_epoch) == 0: + return cfg.lr + + for e in cfg.lr_dec_epoch: + if epoch < e: + break + if epoch < cfg.lr_dec_epoch[-1]: + idx = cfg.lr_dec_epoch.index(e) + for g in self.optimizer.param_groups: + g['lr'] = cfg.lr / (cfg.lr_dec_factor ** idx) + else: + for g in self.optimizer.param_groups: + g['lr'] = cfg.lr / (cfg.lr_dec_factor ** len(cfg.lr_dec_epoch)) + + def get_lr(self): + for g in self.optimizer.param_groups: + cur_lr = g['lr'] + + return cur_lr + def _make_batch_generator(self): + # data load and construct batch generator + self.logger.info("Creating train dataset...") + trainset_loader = Dataset(transforms.ToTensor(), "train") + batch_generator = DataLoader(dataset=trainset_loader, batch_size=cfg.num_gpus*cfg.train_batch_size, 
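+                                      # the loader batch spans all GPUs; after the DataParallel split in _make_model,
+                                      # each GPU processes cfg.train_batch_size samples per iteration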
shuffle=True, + num_workers=cfg.num_thread, pin_memory=True, drop_last=True) + + self.joint_num = trainset_loader.joint_num + self.itr_per_epoch = math.ceil(trainset_loader.__len__() / cfg.num_gpus / cfg.train_batch_size) + self.batch_generator = batch_generator + + def _make_model(self): + # prepare network + self.logger.info("Creating graph and optimizer...") + model = get_model('train', self.joint_num) + model = DataParallel(model).cuda() + optimizer = self.get_optimizer(model) + if cfg.continue_train: + start_epoch, model, optimizer = self.load_model(model, optimizer) + else: + start_epoch = 0 + model.train() + + self.start_epoch = start_epoch + self.model = model + self.optimizer = optimizer + + def save_model(self, state, epoch): + file_path = osp.join(cfg.model_dir,'snapshot_{}.pth.tar'.format(str(epoch))) + torch.save(state, file_path) + self.logger.info("Write snapshot into {}".format(file_path)) + + def load_model(self, model, optimizer): + model_file_list = glob.glob(osp.join(cfg.model_dir,'*.pth.tar')) + cur_epoch = max([int(file_name[file_name.find('snapshot_') + 9 : file_name.find('.pth.tar')]) for file_name in model_file_list]) + model_path = osp.join(cfg.model_dir, 'snapshot_' + str(cur_epoch) + '.pth.tar') + self.logger.info('Load checkpoint from {}'.format(model_path)) + ckpt = torch.load(model_path) + start_epoch = ckpt['epoch'] + 1 + + model.load_state_dict(ckpt['network']) + try: + optimizer.load_state_dict(ckpt['optimizer']) + except: + pass + + return start_epoch, model, optimizer + + +class Tester(Base): + + def __init__(self): + super(Tester, self).__init__(log_name = 'test_logs.txt') + + def _make_batch_generator(self, test_set): + # data load and construct batch generator + self.logger.info("Creating " + test_set + " dataset...") + testset_loader = Dataset(transforms.ToTensor(), test_set) + batch_generator = DataLoader(dataset=testset_loader, batch_size=cfg.num_gpus*cfg.test_batch_size, shuffle=False, num_workers=cfg.num_thread, pin_memory=True) + + self.joint_num = testset_loader.joint_num + self.batch_generator = batch_generator + self.testset = testset_loader + + def _make_model(self): + model_path = os.path.join(cfg.model_dir, 'snapshot.pth.tar') + assert os.path.exists(model_path), 'Cannot find model at ' + model_path + self.logger.info('Load checkpoint from {}'.format(model_path)) + + # prepare network + self.logger.info("Creating graph...") + model = get_model('test', self.joint_num) + model = DataParallel(model).cuda() + ckpt = torch.load(model_path) + model.load_state_dict(ckpt['network']) + model.eval() + + self.model = model + + def _evaluate(self, preds): + mpjpe_dict = self.testset.evaluate(preds) + return mpjpe_dict \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..1c60977 --- /dev/null +++ b/config.py @@ -0,0 +1,141 @@ +import os +import os.path as osp +import sys +import math +import numpy as np + +def clean_file(path): + ## Clear the files under the path + for i in os.listdir(path): + content_path = os.path.join(path, i) + if os.path.isdir(content_path): + clean_file(content_path) + else: + assert os.path.isfile(content_path) is True + os.remove(content_path) + + + +class Config: + # ~~~~~~~~~~~~~~~~~~~~~~Dataset~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + dataset = 'InterHand2.6M' # InterHand2.6M + pose_representation = '2p5D' #2p5D + + + # ~~~~~~~~~~~~~~~~~~~~~~ paths~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + ## Please set your path + ## Interhand2.6M dataset path. you should change to your dataset path. 
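+    ## expected layout (see the README): <your_dataset_path>/Interhand2.6M_5fps/ containing
+    ## the `annotations/` and `images/` sub-folders; point the two paths below at them.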
+ interhand_anno_dir = '/data/data1/Interhand2.6M_5fps/annotations' + interhand_images_path = '/data/data1/Interhand2.6M_5fps/images' + ## current file dir. change this path to your A2J-Transformer folder dir. + cur_dir = '/data/data2/a2jformer/camera_ready' + + + # ~~~~~~~~~~~~~~~~~~~~~~~~input, output~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + input_img_shape = (256, 256) + output_hm_shape = (256, 256, 256) # (depth, height, width) + output_hm_shape_all = 256 ## For convenient + sigma = 2.5 + bbox_3d_size = 400 # depth axis + bbox_3d_size_root = 400 # depth axis + output_root_hm_shape = 64 # depth axis + + + # ~~~~~~~~~~~~~~~~~~~~~~~~backbone config~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + num_feature_levels = 4 + lr_backbone = 1e-4 + masks = False + backbone = 'resnet50' + dilation = True # If true, we replace stride with dilation in the last convolutional block (DC5) + + + # ~~~~~~~~~~~~~~~~~~~~~~~~transformer config~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + position_embedding = 'sine' #'sine' #'convLearned' # learned + hidden_dim = 256 + dropout = 0.1 + nheads = 8 + dim_feedforward = 1024 + enc_layers = 6 + dec_layers = 6 + pre_norm = False + num_feature_levels = 4 + dec_n_points = 4 + enc_n_points = 4 + num_queries = 768 ## query numbers, default is 256*3 = 768 + kernel_size = 256 + two_stage = False ## Whether to use the two-stage deformable-detr, please select False. + use_dab = True ## Whether to use dab-detr, please select True. + num_patterns = 0 + anchor_refpoints_xy = True ## Whether to use the anchor anchor point as the reference point coordinate, True. + is_3D = True # True + fix_anchor = True ## Whether to fix the position of reference points to prevent update, True. + use_lvl_weights = False ## Whether to assign different weights to the loss of each layer, the improvement is relatively limited. + lvl_weights = [0.1, 0.15, 0.15, 0.15, 0.15, 0.3] + + + # ~~~~~~~~~~~~~~~~~~~~~~~~a2j config~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + RegLossFactor = 3 + + + # ~~~~~~~~~~~~~~~~~~~~~~~~training config~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + lr_dec_epoch = [24, 35] if dataset == 'InterHand2.6M' else [45,47] + end_epoch = 42 if dataset == 'InterHand2.6M' else 50 + lr = 1e-4 + lr_dec_factor = 5 + train_batch_size = 12 + continue_train = False ## Whether to continue training, default is False + + + # ~~~~~~~~~~~~~~~~~~~~~~~~testing config~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + test_batch_size = 48 + trans_test = 'gt' ## 'gt', 'rootnet' # 'rootnet' is not used + + + # ~~~~~~~~~~~~~~~~~~~~~~~~dataset config~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + use_single_hand_dataset = True ## Use single-handed data, default is True + use_inter_hand_dataset = True ## Using interacting hand data, default is True + + + # ~~~~~~~~~~~~~~~~~~~~~~~~others~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + num_thread = 8 + gpu_ids = '0' ## your gpu ids, for example, '0', '1-3' + num_gpus = 1 + + + # ~~~~~~~~~~~~~~~~~~~~~~~~directory setup~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + data_dir = osp.join(cur_dir, 'data') + output_dir = osp.join(cur_dir, 'output') + datalistDir = osp.join(cur_dir, 'datalist') ## this is used to save the dataset datalist, easy to debug. 
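+    ## the output sub-directories below (vis_2d, vis_3d, log, result, model_dump, tensorboard_log),
+    ## like the datalist folder above, are created by the make_folder() calls at the bottom of this file.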
+ vis_2d_dir = osp.join(output_dir, 'vis_2d') + vis_3d_dir = osp.join(output_dir, 'vis_3d') + log_dir = osp.join(output_dir, 'log') + result_dir = osp.join(output_dir, 'result') + model_dir = osp.join(output_dir, 'model_dump') + tensorboard_dir = osp.join(output_dir, 'tensorboard_log') + clean_tensorboard_dir = False + clean_log_dir = False + if clean_tensorboard_dir is True: + clean_file(tensorboard_dir) + if clean_log_dir is True: + clean_file(log_dir) + + + def set_args(self, gpu_ids, continue_train=False): + self.gpu_ids = gpu_ids + self.num_gpus = len(self.gpu_ids.split(',')) + self.continue_train = continue_train + os.environ["CUDA_VISIBLE_DEVICES"] = self.gpu_ids + print('>>> Using GPU: {}'.format(self.gpu_ids)) + + +cfg = Config() +from utils.dir import add_pypath, make_folder +add_pypath(osp.join(cfg.data_dir)) +add_pypath(osp.join(cfg.data_dir, cfg.dataset)) +make_folder(cfg.datalistDir) +make_folder(cfg.model_dir) +make_folder(cfg.vis_2d_dir) +make_folder(cfg.vis_3d_dir) +make_folder(cfg.log_dir) +make_folder(cfg.result_dir) +make_folder(cfg.tensorboard_dir) \ No newline at end of file diff --git a/dab_deformable_detr/__init__.py b/dab_deformable_detr/__init__.py new file mode 100644 index 0000000..0eaf756 --- /dev/null +++ b/dab_deformable_detr/__init__.py @@ -0,0 +1,17 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# DModified from eformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +# from .dab_deformable_detr import build_dab_deformable_detr + + + diff --git a/dab_deformable_detr/backbone.py b/dab_deformable_detr/backbone.py new file mode 100644 index 0000000..4af65e9 --- /dev/null +++ b/dab_deformable_detr/backbone.py @@ -0,0 +1,142 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from utils.miscdetr import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. 
+ + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n, eps=1e-5): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = self.eps + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + self.strides = [8, 16, 32] + self.num_channels = [512, 1024, 2048] + else: + return_layers = {'layer4': "0"} + self.strides = [32] + self.num_channels = [2048] + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + norm_layer = FrozenBatchNorm2d + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=norm_layer) + assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" + super().__init__(backbone, train_backbone, return_interm_layers) + if dilation: + self.strides[-1] = self.strides[-1] // 2 + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + self.strides = backbone.strides + self.num_channels = backbone.num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in sorted(xs.items()): + out.append(x) + + # position encoding + for x in out: + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(cfg): + position_embedding = build_position_encoding(cfg) + train_backbone = cfg.lr_backbone > 0 + return_interm_layers = 
cfg.masks or (cfg.num_feature_levels > 1) + backbone = Backbone(cfg.backbone, train_backbone, return_interm_layers, cfg.dilation) + model = Joiner(backbone, position_embedding) + return model diff --git a/dab_deformable_detr/deformable_transformer.py b/dab_deformable_detr/deformable_transformer.py new file mode 100644 index 0000000..af7dcd2 --- /dev/null +++ b/dab_deformable_detr/deformable_transformer.py @@ -0,0 +1,542 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# DModified from eformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import copy +from typing import Optional, List +import math + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ + +from utils.miscdetr import inverse_sigmoid +from .ops.modules import MSDeformAttn + +from config import cfg + + +class DeformableTransformer(nn.Module): + def __init__(self, d_model=256, nhead=8, + num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, + activation="relu", return_intermediate_dec=False, + num_feature_levels=4, dec_n_points=4, enc_n_points=4, + two_stage=False, two_stage_num_proposals=300, + use_dab=False, fix_anchor=False, high_dim_query_update=False, no_sine_embed=False): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + self.two_stage = two_stage # default False + self.two_stage_num_proposals = two_stage_num_proposals + self.use_dab = use_dab + self.fix_anchor = fix_anchor + + encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, enc_n_points) + self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, dec_n_points) + self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec, + use_dab=use_dab, d_model=d_model, high_dim_query_update=high_dim_query_update, + fix_anchor=fix_anchor, no_sine_embed=no_sine_embed) + + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if two_stage: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + self.pos_trans = nn.Linear(d_model * 2, d_model * 2) + self.pos_trans_norm = nn.LayerNorm(d_model * 2) + else: + if not self.use_dab: + self.reference_points = nn.Linear(d_model, 2) + + self.high_dim_query_update = high_dim_query_update + if high_dim_query_update: + assert not self.use_dab, "use_dab must be True" + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + if not self.two_stage and not self.use_dab: + 
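+            # the learned reference-point projection only exists in the vanilla (non-two-stage, non-DAB)
+            # variant; with cfg.use_dab=True the 3D anchor points themselves provide the reference
+            # coordinates, so there is no extra layer to initialize here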
xavier_uniform_(self.reference_points.weight.data, gain=1.0) + constant_(self.reference_points.bias.data, 0.) + normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + N_, S_, C_ = memory.shape + base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += (H_ * W_) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, pos_embeds, query_embed=None): + """ + Input: + - srcs: List([bs, c, h, w]) + - masks: List([bs, h, w]) + """ + assert self.two_stage or query_embed is not None + + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + + src = src.flatten(2).transpose(1, 2) # bs, hw, c + mask = mask.flatten(1) # bs, hw + pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c + mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} + 
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # encoder + memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) + + # prepare input for decoder + bs, _, c = memory.shape + + if self.two_stage: + # not used + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + + # hack implementation for two-stage Deformable DETR + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] # topk函数用来求top k个极大or极小值 + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_embed, tgt = torch.split(pos_trans_out, c, dim=2) + + elif self.use_dab: + ## using + reference_points = query_embed[..., self.d_model:] + tgt = query_embed[..., :self.d_model] + tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + init_reference_out = reference_points + + else: + # not used + query_embed, tgt = torch.split(query_embed, c, dim=1) + query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) + tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_embed).sigmoid() + # bs, num_quires, 2 + init_reference_out = reference_points + + # decoder + hs, inter_references = self.decoder(tgt, reference_points, memory, + spatial_shapes, level_start_index, valid_ratios, + query_pos=query_embed if not self.use_dab else None, # dab不需要由query得到位置emb, + src_padding_mask=mask_flatten) + + inter_references_out = inter_references + if self.two_stage: + return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact + return hs, init_reference_out, inter_references_out, None, None + + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__(self, + d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4): + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): + # self attention + src2 = self.self_attn(self.with_pos_embed(src, pos), 
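+                               # query: flattened multi-scale features plus positional/level embedding;
+                               # values are sampled from `src` around each per-level reference point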
reference_points, src, spatial_shapes, level_start_index, padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + + return src + + +class DeformableTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): + """ + Input: + - src: [bs, sum(hi*wi), 256] + - spatial_shapes: h,w of each level [num_level, 2] + - level_start_index: [num_level] start point of level in sum(hi*wi). + - valid_ratios: [bs, num_level, 2] + - pos: pos embed for src. [bs, sum(hi*wi), 256] + - padding_mask: [bs, sum(hi*wi)] + Intermedia: + - reference_points: [bs, sum(hi*wi), num_lebel, 2] + """ + output = src + # bs, sum(hi*wi), 256 + # import ipdb; ipdb.set_trace() + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) + for _, layer in enumerate(self.layers): + output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) + + return output + + +class DeformableTransformerDecoderLayer(nn.Module): + def __init__(self, d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4): + super().__init__() + + # cross attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None): + # self attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), + reference_points, + src, src_spatial_shapes, level_start_index, src_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = 
self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, return_intermediate=False, use_dab=False, d_model=256, high_dim_query_update=False, fix_anchor=False, no_sine_embed=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + self.use_dab = use_dab + self.d_model = d_model + self.no_sine_embed = no_sine_embed + self.fix_anchor = fix_anchor + if use_dab: + self.query_scale = MLP(d_model, d_model, d_model, 2) + if self.no_sine_embed: + self.ref_point_head = MLP(2, d_model, d_model, 2) # 这里2表示xy,去掉hw + else: + # self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2) + # self.ref_point_head = MLP(d_model, d_model, d_model, 2) + dmodel_for_3d_anchor = int(d_model * 1.5) + self.ref_point_head = MLP(dmodel_for_3d_anchor, d_model, d_model, 2) + + self.high_dim_query_update = high_dim_query_update + if high_dim_query_update: + self.high_dim_query_proj = MLP(d_model, d_model, d_model, 2) + + + def forward(self, tgt, reference_points, src, src_spatial_shapes, + src_level_start_index, src_valid_ratios, + query_pos=None, src_padding_mask=None): + output = tgt + if self.use_dab: + assert query_pos is None + bs = src.shape[0] + reference_points = reference_points[None].repeat(bs, 1, 1) # bs, nq, 4(xywh) + + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + # import ipdb; ipdb.set_trace() + if reference_points.shape[-1] == 4: + reference_points_input = reference_points[:, :, None] \ + * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] # bs, nq, 4, 4 + elif reference_points.shape[-1] == 2: ## xy + reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] + elif reference_points.shape[-1] == 3: ## xyz + reference_points_input = reference_points.unsqueeze(2).repeat(1,1,4,1) + if self.use_dab: + # import ipdb; ipdb.set_trace() + if self.no_sine_embed: + raw_query_pos = self.ref_point_head(reference_points_input) + else: + query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 384(128*3) + raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256 + pos_scale = self.query_scale(output) if lid != 0 else 1 + + query_pos = pos_scale * raw_query_pos + + # Update the features of the query + if self.high_dim_query_update and lid != 0: + query_pos = query_pos + self.high_dim_query_proj(output) + + if reference_points.shape[-1] == 3: + reference_points_input = reference_points_input[:,:,:,0:2] + output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid](output) + if reference_points.shape[-1] == 4: ## ori dab-detr + new_reference_points = tmp + reference_points # inverse sigmoid is not need + + else: + assert reference_points.shape[-1] == 3 + if self.fix_anchor is True: ## not update reference_points + new_reference_points = reference_points + else: + new_reference_points = tmp + reference_points # inverse sigmoid is not need + reference_points = new_reference_points.detach() + + if self.return_intermediate: + 
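+                # keep each layer's output and its reference points (fixed when cfg.fix_anchor is True)
+                # so every decoder layer's prediction can be supervised later
+                # (see cfg.use_lvl_weights / cfg.lvl_weights)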
intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +def build_deforamble_transformer(cfg): + return DeformableTransformer( + d_model=cfg.hidden_dim, + nhead=cfg.nheads, + num_encoder_layers=cfg.enc_layers, + num_decoder_layers=cfg.dec_layers, + dim_feedforward=cfg.dim_feedforward, + dropout=cfg.dropout, + activation="relu", + return_intermediate_dec=True, + num_feature_levels=cfg.num_feature_levels, + dec_n_points=cfg.dec_n_points, + enc_n_points=cfg.enc_n_points, + two_stage=cfg.two_stage, + two_stage_num_proposals=cfg.num_queries, + use_dab=cfg.use_dab, + fix_anchor=cfg.fix_anchor) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +def gen_sineembed_for_position(pos_tensor): + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * (dim_t // 2) / 128) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: ## ori + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + + elif pos_tensor.size(-1) == 3: ## my + z_embed = pos_tensor[:, :, 2] * scale + pos_z = z_embed[:, :, None] / dim_t + pos_z = torch.stack((pos_z[:, :, 0::2].sin(), pos_z[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_z), dim=2) + + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos \ No newline at end of file diff --git a/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO b/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO new file mode 100644 index 0000000..5f86c90 --- /dev/null +++ b/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO @@ -0,0 +1,6 @@ +Metadata-Version: 2.1 +Name: 
MultiScaleDeformableAttention +Version: 1.0 +Summary: PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention +Home-page: https://github.com/fundamentalvision/Deformable-DETR +Author: Weijie Su diff --git a/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt b/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt new file mode 100644 index 0000000..b549de4 --- /dev/null +++ b/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt @@ -0,0 +1,15 @@ +setup.py +/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/vision.cpp +/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp +/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu +/data/data2/a2jformer/code/dab_deformable_detr/ops/src/vision.cpp +/data/data2/a2jformer/code/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp +/data/data2/a2jformer/code/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu +MultiScaleDeformableAttention.egg-info/PKG-INFO +MultiScaleDeformableAttention.egg-info/SOURCES.txt +MultiScaleDeformableAttention.egg-info/dependency_links.txt +MultiScaleDeformableAttention.egg-info/top_level.txt +functions/__init__.py +functions/ms_deform_attn_func.py +modules/__init__.py +modules/ms_deform_attn.py \ No newline at end of file diff --git a/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt b/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/top_level.txt b/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/top_level.txt new file mode 100644 index 0000000..25d8f77 --- /dev/null +++ b/dab_deformable_detr/ops/MultiScaleDeformableAttention.egg-info/top_level.txt @@ -0,0 +1,3 @@ +MultiScaleDeformableAttention +functions +modules diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so new file mode 100644 index 0000000..a37a312 Binary files /dev/null and b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so differ diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/functions/__init__.py b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/functions/__init__.py new file mode 100644 index 0000000..8a2197b --- /dev/null +++ b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/functions/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction + diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/functions/ms_deform_attn_func.py b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/functions/ms_deform_attn_func.py new file mode 100644 index 0000000..8c5df8c --- /dev/null +++ b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/functions/ms_deform_attn_func.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = \ + MSDA.ms_deform_attn_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=False) + 
sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/modules/__init__.py b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/modules/__init__.py new file mode 100644 index 0000000..f82cb1a --- /dev/null +++ b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/modules/ms_deform_attn.py b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/modules/ms_deform_attn.py new file mode 100644 index 0000000..663d64a --- /dev/null +++ b/dab_deformable_detr/ops/build/lib.linux-x86_64-3.7/modules/ms_deform_attn.py @@ -0,0 +1,115 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n-1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 64 + + 
self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + output = MSDeformAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, sampling_locations, 
attention_weights, self.im2col_step) + output = self.output_proj(output) + return output diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/MultiScaleDeformableAttention.cpython-38-x86_64-linux-gnu.so b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/MultiScaleDeformableAttention.cpython-38-x86_64-linux-gnu.so new file mode 100644 index 0000000..88d72aa Binary files /dev/null and b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/MultiScaleDeformableAttention.cpython-38-x86_64-linux-gnu.so differ diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/functions/__init__.py b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/functions/__init__.py new file mode 100644 index 0000000..8a2197b --- /dev/null +++ b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/functions/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction + diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/functions/ms_deform_attn_func.py b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/functions/ms_deform_attn_func.py new file mode 100644 index 0000000..8c5df8c --- /dev/null +++ b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/functions/ms_deform_attn_func.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = \ + MSDA.ms_deform_attn_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/modules/__init__.py b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/modules/__init__.py new file mode 100644 index 0000000..f82cb1a --- /dev/null +++ b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/modules/ms_deform_attn.py b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/modules/ms_deform_attn.py new file mode 100644 index 0000000..663d64a --- /dev/null +++ b/dab_deformable_detr/ops/build/lib.linux-x86_64-cpython-38/modules/ms_deform_attn.py @@ -0,0 +1,115 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n-1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) 
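+        # With the offset weights zeroed above, the initial sampling offsets come entirely
+        # from the bias set below: each of the n_heads is given a distinct direction evenly
+        # spaced around a circle, normalized so its larger coordinate is 1, and the i-th
+        # sampling point of a head is scaled by (i + 1) so the points fan outward along
+        # that direction before any training.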
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + output = MSDeformAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) + output = self.output_proj(output) + return output diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o b/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o new file mode 100644 index 0000000..59e7f29 Binary files /dev/null and 
b/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o differ diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o b/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o new file mode 100644 index 0000000..537bb68 Binary files /dev/null and b/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o differ diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/vision.o b/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/vision.o new file mode 100644 index 0000000..a04bc23 Binary files /dev/null and b/dab_deformable_detr/ops/build/temp.linux-x86_64-3.7/data2/jiangchanglong/a2jformer/code3/dab_deformable_detr/ops/src/vision.o differ diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/.ninja_deps b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/.ninja_deps new file mode 100644 index 0000000..845e8c1 Binary files /dev/null and b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/.ninja_deps differ diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/.ninja_log b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/.ninja_log new file mode 100644 index 0000000..fe8eb4a --- /dev/null +++ b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/.ninja_log @@ -0,0 +1,7 @@ +# ninja log v5 +0 4331 1679139496513782543 /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o e5275a58a972aad7 +1 12199 1679139504369771995 /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o 4d63f9a8b5055438 +1 24178 1679139516345755838 /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/vision.o 88dd7ad6ddf5cffd +3 4000 1679141555962982953 /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o e21fd8c225d636f5 +4 11334 1679141563286994075 /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o 92772fe5bb9d14 +4 22547 1679141574499010908 /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/vision.o 3f829870476a0f61 diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/build.ninja b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/build.ninja new file mode 100644 index 0000000..bf84e92 --- /dev/null +++ b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/build.ninja @@ -0,0 +1,28 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda-11.0/bin/nvcc + +cflags = -pthread -B 
/home/wucunlin/anaconda3/envs/py38/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DWITH_CUDA -I/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src -I/home/wucunlin/anaconda3/envs/py38/lib/python3.8/site-packages/torch/include -I/home/wucunlin/anaconda3/envs/py38/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/wucunlin/anaconda3/envs/py38/lib/python3.8/site-packages/torch/include/TH -I/home/wucunlin/anaconda3/envs/py38/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.0/include -I/home/wucunlin/anaconda3/envs/py38/include/python3.8 -c +post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14 +cuda_cflags = -DWITH_CUDA -I/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src -I/home/wucunlin/anaconda3/envs/py38/lib/python3.8/site-packages/torch/include -I/home/wucunlin/anaconda3/envs/py38/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/wucunlin/anaconda3/envs/py38/lib/python3.8/site-packages/torch/include/TH -I/home/wucunlin/anaconda3/envs/py38/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.0/include -I/home/wucunlin/anaconda3/envs/py38/include/python3.8 -c +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_80,code=sm_80 -std=c++14 +ldflags = + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = $out.d + deps = gcc + +rule cuda_compile + command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags + + + +build /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o: compile /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp +build /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o: cuda_compile /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu +build /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/vision.o: compile /data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/vision.cpp + + + + + diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o new file mode 100644 index 0000000..db5bee4 Binary files /dev/null and b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.o differ diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o 
b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o new file mode 100644 index 0000000..a8c8754 Binary files /dev/null and b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.o differ diff --git a/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/vision.o b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/vision.o new file mode 100644 index 0000000..6365630 Binary files /dev/null and b/dab_deformable_detr/ops/build/temp.linux-x86_64-cpython-38/data/data2/a2jformer/camera_ready/dab_deformable_detr/ops/src/vision.o differ diff --git a/dab_deformable_detr/ops/dist/MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg b/dab_deformable_detr/ops/dist/MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg new file mode 100644 index 0000000..63c71d7 Binary files /dev/null and b/dab_deformable_detr/ops/dist/MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg differ diff --git a/dab_deformable_detr/ops/dist/MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg b/dab_deformable_detr/ops/dist/MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg new file mode 100644 index 0000000..e2aff5e Binary files /dev/null and b/dab_deformable_detr/ops/dist/MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg differ diff --git a/dab_deformable_detr/ops/functions/__init__.py b/dab_deformable_detr/ops/functions/__init__.py new file mode 100644 index 0000000..8a2197b --- /dev/null +++ b/dab_deformable_detr/ops/functions/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction + diff --git a/dab_deformable_detr/ops/functions/__pycache__/__init__.cpython-37.pyc b/dab_deformable_detr/ops/functions/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..dcfdaa8 Binary files /dev/null and b/dab_deformable_detr/ops/functions/__pycache__/__init__.cpython-37.pyc differ diff --git a/dab_deformable_detr/ops/functions/__pycache__/__init__.cpython-38.pyc b/dab_deformable_detr/ops/functions/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..f922c9a Binary files /dev/null and b/dab_deformable_detr/ops/functions/__pycache__/__init__.cpython-38.pyc differ diff --git a/dab_deformable_detr/ops/functions/__pycache__/ms_deform_attn_func.cpython-37.pyc b/dab_deformable_detr/ops/functions/__pycache__/ms_deform_attn_func.cpython-37.pyc new file mode 100644 index 0000000..d6e6714 Binary files /dev/null and b/dab_deformable_detr/ops/functions/__pycache__/ms_deform_attn_func.cpython-37.pyc differ diff --git a/dab_deformable_detr/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc b/dab_deformable_detr/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc new file mode 100644 index 0000000..f306a3d Binary files /dev/null and b/dab_deformable_detr/ops/functions/__pycache__/ms_deform_attn_func.cpython-38.pyc differ diff --git a/dab_deformable_detr/ops/functions/ms_deform_attn_func.py b/dab_deformable_detr/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000..8c5df8c --- /dev/null +++ b/dab_deformable_detr/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +import MultiScaleDeformableAttention as MSDA + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = \ + MSDA.ms_deform_attn_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) + return output.transpose(1, 2).contiguous() diff --git a/dab_deformable_detr/ops/make.sh b/dab_deformable_detr/ops/make.sh new file mode 100644 index 0000000..eae08d0 --- /dev/null +++ b/dab_deformable_detr/ops/make.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + + +# TORCH_CUDA_ARCH_LIST="8.0" CUDA_HOME='/path/to/your/cuda/dir' +python setup.py build install diff --git a/dab_deformable_detr/ops/modules/__init__.py b/dab_deformable_detr/ops/modules/__init__.py new file mode 100644 index 0000000..f82cb1a --- /dev/null +++ b/dab_deformable_detr/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/dab_deformable_detr/ops/modules/__pycache__/__init__.cpython-37.pyc b/dab_deformable_detr/ops/modules/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..ce62655 Binary files /dev/null and b/dab_deformable_detr/ops/modules/__pycache__/__init__.cpython-37.pyc differ diff --git a/dab_deformable_detr/ops/modules/__pycache__/__init__.cpython-38.pyc b/dab_deformable_detr/ops/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..c85e447 Binary files /dev/null and b/dab_deformable_detr/ops/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/dab_deformable_detr/ops/modules/__pycache__/ms_deform_attn.cpython-37.pyc b/dab_deformable_detr/ops/modules/__pycache__/ms_deform_attn.cpython-37.pyc new file mode 100644 index 0000000..9c8c2f3 Binary files /dev/null and b/dab_deformable_detr/ops/modules/__pycache__/ms_deform_attn.cpython-37.pyc differ diff --git a/dab_deformable_detr/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc b/dab_deformable_detr/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc new file mode 100644 index 0000000..fa7ab1a Binary files /dev/null and b/dab_deformable_detr/ops/modules/__pycache__/ms_deform_attn.cpython-38.pyc differ diff --git a/dab_deformable_detr/ops/modules/ms_deform_attn.py b/dab_deformable_detr/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000..663d64a --- /dev/null +++ b/dab_deformable_detr/ops/modules/ms_deform_attn.py @@ -0,0 +1,115 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n-1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) 
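+    # Illustrative usage sketch (not part of the upstream Deformable-DETR code); the tensor
+    # names and sizes below are assumptions chosen only to satisfy the shapes documented in
+    # forward(), and all tensors must live on a CUDA device since the custom op has no CPU kernel:
+    #   attn = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4).cuda()
+    #   spatial_shapes = torch.as_tensor([[64, 64], [32, 32], [16, 16], [8, 8]], dtype=torch.long).cuda()
+    #   level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
+    #   src_flatten = torch.rand(2, 5440, 256).cuda()        # (N, sum(H_l*W_l), C), 5440 = 64*64+32*32+16*16+8*8
+    #   query = torch.rand(2, 300, 256).cuda()               # (N, Len_q, C)
+    #   reference_points = torch.rand(2, 300, 4, 2).cuda()   # (N, Len_q, n_levels, 2), values in [0, 1]
+    #   out = attn(query, reference_points, src_flatten, spatial_shapes, level_start_index)  # (N, Len_q, C)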
+ + def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + output = MSDeformAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) + output = self.output_proj(output) + return output diff --git a/dab_deformable_detr/ops/setup.py b/dab_deformable_detr/ops/setup.py new file mode 100644 index 0000000..049f923 --- /dev/null +++ b/dab_deformable_detr/ops/setup.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +import os +import glob + +import torch + +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +from setuptools import find_packages +from setuptools import setup + +requirements = ["torch", "torchvision"] + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + # import ipdb; ipdb.set_trace() + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError('Cuda is not availabel') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "MultiScaleDeformableAttention", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + +setup( + name="MultiScaleDeformableAttention", + version="1.0", + author="Weijie Su", + url="https://github.com/fundamentalvision/Deformable-DETR", + description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", + packages=find_packages(exclude=("configs", "tests",)), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp b/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000..e1bf854 --- /dev/null +++ b/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp @@ -0,0 +1,41 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + diff --git a/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.h b/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.h new file mode 100644 index 0000000..81b7b58 --- /dev/null +++ b/dab_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.h @@ -0,0 +1,33 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + + diff --git a/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu b/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu new file mode 100644 index 0000000..d6d5836 --- /dev/null +++ b/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,153 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), 
"spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} \ No newline at end of file diff --git a/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.h b/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000..c7ae53f --- /dev/null +++ b/dab_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,30 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/dab_deformable_detr/ops/src/cuda/ms_deform_im2col_cuda.cuh b/dab_deformable_detr/ops/src/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000..6bc2acb --- /dev/null +++ b/dab_deformable_detr/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh 
* lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + 
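+  // The code below mirrors the forward bilinear sampling: w1..w4 are the four corner
+  // weights, grad_value is accumulated per corner with atomicAdd, and grad_h_weight /
+  // grad_w_weight collect the derivative of the sampled value w.r.t. the (h, w) location,
+  // which is then scaled by the feature-map height/width because the sampling location
+  // is given in normalized [0, 1] coordinates. Unlike the shared-memory reduction kernels,
+  // this *_gm variant also writes grad_sampling_loc and grad_attn_weight to global memory
+  // with atomicAdd.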
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + 
} + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void 
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template 
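+// Unlike the blocksize_aware kernels above, the shm_reduce_* kernels below size their caches
+// with dynamically allocated shared memory (extern __shared__), so they accept arbitrary
+// block sizes; they are used by the dispatcher's default branch for channel counts without a
+// dedicated case. This v1 variant again lets thread 0 sum the caches serially.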
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void 
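+// shm_reduce_v2 below performs the tree reduction over blockDim.x threads; the extra `spre`
+// check folds in the trailing element whenever the active width is odd, so block sizes that
+// are not powers of two are reduced correctly.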
ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + 
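+          // after the reduction, element 0 of each cache holds the block-wide sum; a plain
+          // store is safe here because each (query, level, point) gradient slot is covered by
+          // exactly one block, unlike the *_multi_blocks variant further below, which must
+          // accumulate with atomicAdd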
*grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += 
cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int 
batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + 
num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + 
num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/dab_deformable_detr/ops/src/ms_deform_attn.h b/dab_deformable_detr/ops/src/ms_deform_attn.h new file mode 100644 index 0000000..ac0ef2e --- /dev/null +++ b/dab_deformable_detr/ops/src/ms_deform_attn.h @@ -0,0 +1,62 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/dab_deformable_detr/ops/src/vision.cpp b/dab_deformable_detr/ops/src/vision.cpp new file mode 100644 index 0000000..2201f63 --- /dev/null +++ b/dab_deformable_detr/ops/src/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/dab_deformable_detr/ops/test.py b/dab_deformable_detr/ops/test.py new file mode 100644 index 0000000..8dbf6d5 --- /dev/null +++ b/dab_deformable_detr/ops/test.py @@ -0,0 +1,89 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H*W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() + fwdok = 
torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) + + print(f'* {gradok} check_gradient_numerical(D={channels})') + + +if __name__ == '__main__': + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) + + + diff --git a/dab_deformable_detr/position_encoding.py b/dab_deformable_detr/position_encoding.py new file mode 100644 index 0000000..cc3741f --- /dev/null +++ b/dab_deformable_detr/position_encoding.py @@ -0,0 +1,102 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# DModified from eformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from utils.miscdetr import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
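+    Each of the y and x coordinates is encoded into num_pos_feats channels of interleaved
+    sines and cosines whose wavelengths grow geometrically with `temperature`; with
+    normalize=True the coordinates are first rescaled to roughly [0, scale] (default 2*pi).
+    The returned map therefore has 2*num_pos_feats channels (y-encodings then x-encodings).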
+ """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + # mask = (mask.squeeze(1)<0) # 不加也行 + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(cfg): + N_steps = cfg.hidden_dim // 2 + if cfg.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif cfg.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {cfg.position_embedding}") + + return position_embedding diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..ae0dde5 --- /dev/null +++ b/dataset.py @@ -0,0 +1,432 @@ +import numpy as np +import torch +import torch.utils.data +import cv2 +import os.path as osp +from config import cfg +from utils.preprocessing import load_img, load_skeleton, get_bbox, process_bbox, augmentation, transform_input_to_output_space, trans_point2d +from utils.transforms import world2cam, cam2pixel, pixel2cam +from utils.vis import vis_keypoints, vis_3d_keypoints +import json +from pycocotools.coco import COCO +from tqdm import tqdm +import pickle + + +class Dataset(torch.utils.data.Dataset): + def __init__(self, transform, mode): + self.mode = mode # train, test, val + self.img_path = cfg.interhand_images_path + self.annot_path = cfg.interhand_anno_dir + self.datalist_dir = cfg.datalistDir + if self.mode == 'val': + self.rootnet_output_path = 
'../rootnet_output/rootnet_interhand2.6m_output_val.json' + else: + self.rootnet_output_path = '../rootnet_output/rootnet_interhand2.6m_output_test.json' + self.transform = transform + self.joint_num = 21 # single hand + self.root_joint_idx = {'right': 20, 'left': 41} + self.joint_type = {'right': np.arange(0,self.joint_num), 'left': np.arange(self.joint_num,self.joint_num*2)} + self.skeleton = load_skeleton(osp.join(self.annot_path, 'skeleton.txt'), self.joint_num*2) + self.use_single_hand_dataset = cfg.use_single_hand_dataset + self.use_inter_hand_dataset = cfg.use_inter_hand_dataset + self.vis = False + + ## use the total Interhand2.6M dataset + datalist_file_path_sh = osp.join(self.datalist_dir , mode + '_datalist_sh_all.pkl') + datalist_file_path_ih = osp.join(self.datalist_dir , mode + '_datalist_ih_all.pkl') + + # generate_new_datalist : whether to get datalist from existing file + generate_new_datalist = True + if osp.exists(datalist_file_path_sh) and osp.exists(datalist_file_path_ih): + if (osp.getsize(datalist_file_path_sh) + osp.getsize(datalist_file_path_ih)) != 0: + generate_new_datalist = False + + ## if the datalist is empty or doesn't exist, generate the pkl file and save the datalist + if generate_new_datalist is True: + self.datalist = [] + self.datalist_sh = [] + self.datalist_ih = [] + self.sequence_names = [] + + # load annotation + print("Load annotation from " + osp.join(self.annot_path, self.mode)) + db = COCO(osp.join(self.annot_path, self.mode, 'InterHand2.6M_' + self.mode + '_data.json')) + with open(osp.join(self.annot_path, self.mode, 'InterHand2.6M_' + self.mode + '_camera.json')) as f: + cameras = json.load(f) + with open(osp.join(self.annot_path, self.mode, 'InterHand2.6M_' + self.mode + '_joint_3d.json')) as f: + joints = json.load(f) + + # rootnet is not used + if (self.mode == 'val' or self.mode == 'test') and cfg.trans_test == 'rootnet': + print("Get bbox and root depth from " + self.rootnet_output_path) + rootnet_result = {} + with open(self.rootnet_output_path) as f: + annot = json.load(f) + for i in range(len(annot)): + rootnet_result[str(annot[i]['annot_id'])] = annot[i] + else: + print("Get bbox and root depth from groundtruth annotation") + + # get images and annotations + for aid in tqdm(list(db.anns.keys())[::1]): + ann = db.anns[aid] + image_id = ann['image_id'] + img = db.loadImgs(image_id)[0] + hand_type = ann['hand_type'] + capture_id = img['capture'] + subject = img['subject'] + seq_name = img['seq_name'] + cam = img['camera'] + frame_idx = img['frame_idx'] + img_path = osp.join(self.img_path, self.mode, img['file_name']) + + campos, camrot = np.array(cameras[str(capture_id)]['campos'][str(cam)], dtype=np.float32), np.array(cameras[str(capture_id)]['camrot'][str(cam)], dtype=np.float32) + focal, princpt = np.array(cameras[str(capture_id)]['focal'][str(cam)], dtype=np.float32), np.array(cameras[str(capture_id)]['princpt'][str(cam)], dtype=np.float32) + joint_world = np.array(joints[str(capture_id)][str(frame_idx)]['world_coord'], dtype=np.float32) + joint_cam = world2cam(joint_world.transpose(1,0), camrot, campos.reshape(3,1)).transpose(1,0) + joint_img = cam2pixel(joint_cam, focal, princpt)[:,:2] + joint_valid = np.array(ann['joint_valid'],dtype=np.float32).reshape(self.joint_num*2) + + ## Filter the data that does not meet the training requirements. + ## All preprocessing refers to the baseline of Interhand2.6M(ECCV2020). + # if root is not valid -> root-relative 3D pose is also not valid. 
Therefore, mark all joints as invalid + joint_valid[self.joint_type['right']] *= joint_valid[self.root_joint_idx['right']] + joint_valid[self.joint_type['left']] *= joint_valid[self.root_joint_idx['left']] + # hand_type = ann['hand_type'] + hand_type_valid = np.array((ann['hand_type_valid']), dtype=np.float32) + + # rootnet is not used + if (self.mode == 'val' or self.mode == 'test') and cfg.trans_test == 'rootnet': + bbox = np.array(rootnet_result[str(aid)]['bbox'],dtype=np.float32) + abs_depth = {'right': rootnet_result[str(aid)]['abs_depth'][0], 'left': rootnet_result[str(aid)]['abs_depth'][1]} + else: + img_width, img_height = img['width'], img['height'] + bbox = np.array(ann['bbox'],dtype=np.float32) # x,y,w,h + bbox = process_bbox(bbox, (img_height, img_width)) + abs_depth = {'right': joint_cam[self.root_joint_idx['right'],2], 'left': joint_cam[self.root_joint_idx['left'],2]} #根节点的深度值,以此为参考 + + cam_param = {'focal': focal, 'princpt': princpt} + joint = {'cam_coord': joint_cam, 'img_coord': joint_img, 'valid': joint_valid} + data = {'img_path': img_path, 'seq_name': seq_name, 'cam_param': cam_param, + 'bbox': bbox, 'joint': joint, 'hand_type': hand_type, 'hand_type_valid': hand_type_valid, + 'abs_depth': abs_depth, 'file_name': img['file_name'], 'capture': capture_id, 'cam': cam, + 'frame': frame_idx, 'subject': subject, 'imgid': image_id + } + + if hand_type == 'right' or hand_type == 'left': + if self.use_single_hand_dataset is True: + self.datalist_sh.append(data) + elif hand_type == 'interacting': + if self.use_inter_hand_dataset is True: + self.datalist_ih.append(data) + if seq_name not in self.sequence_names: + self.sequence_names.append(seq_name) + + # Save the generated datalist to pkl file, easy to debug + with open(datalist_file_path_sh, 'wb') as fs: + pickle.dump(self.datalist_sh, fs) + with open(datalist_file_path_ih, 'wb') as fi: + pickle.dump(self.datalist_ih, fi) + + + # Directly load the datalist saved in the previous file + else: + if self.use_single_hand_dataset is True: + with open (datalist_file_path_sh, 'rb') as fsl: + self.datalist_sh = pickle.load(fsl) + else: + self.datalist_sh = [] + if self.use_inter_hand_dataset is True: + with open (datalist_file_path_ih, 'rb') as fil: + self.datalist_ih = pickle.load(fil) + else: + self.datalist_ih = [] + + self.datalist = self.datalist_sh + self.datalist_ih + print('Number of annotations in single hand sequences: ' + str(len(self.datalist_sh))) + print('Number of annotations in interacting hand sequences: ' + str(len(self.datalist_ih))) + + + def handtype_str2array(self, hand_type): + if hand_type == 'right': + return np.array([1,0], dtype=np.float32) + elif hand_type == 'left': + return np.array([0,1], dtype=np.float32) + elif hand_type == 'interacting': + return np.array([1,1], dtype=np.float32) + else: + assert 0, print('Not supported hand type: ' + hand_type) + + def __len__(self): + return len(self.datalist) + + def __getitem__(self, idx): + data = self.datalist[idx] + img_path, bbox, joint, hand_type, hand_type_valid = data['img_path'], data['bbox'], data['joint'], data['hand_type'], data['hand_type_valid'] + joint_cam = joint['cam_coord'].copy(); joint_img = joint['img_coord'].copy(); joint_valid = joint['valid'].copy(); + hand_type = self.handtype_str2array(hand_type) + joint_coord = np.concatenate((joint_img, joint_cam[:,2,None].copy()),1) + seq_name = data['seq_name'] + contact_vis_np = np.zeros((32, 2)).astype(np.float32) + + # image load + img = load_img(img_path) + + # augmentation + img, joint_coord, 
joint_valid, hand_type, inv_trans = augmentation(img, bbox, joint_coord, joint_valid, hand_type, self.mode, self.joint_type) + rel_root_depth = np.array([joint_coord[self.root_joint_idx['left'],2] - joint_coord[self.root_joint_idx['right'],2]],dtype=np.float32).reshape(1) + root_valid = np.array([joint_valid[self.root_joint_idx['right']] * joint_valid[self.root_joint_idx['left']]])*1.0 + + # transform to output heatmap space + joint_coord, joint_valid, rel_root_depth, root_valid =\ + transform_input_to_output_space(joint_coord, joint_valid, rel_root_depth, root_valid, self.root_joint_idx, self.joint_type) + + # Some images are blank, filter for training + if np.sum(img) < 1e-4 : + joint_valid *= 0 + root_valid *= 0 + hand_type_valid *= 0 + contact_vis_np *= 0 + + img = self.transform(img.astype(np.float32)) / 255. + + # use zero mask. + mask = np.zeros((img.shape[1], img.shape[2])).astype(np.bool) + mask = self.transform(mask.astype(np.uint8)) + + inputs = {'img': img, 'mask': mask} + targets = {'joint_coord': joint_coord, 'rel_root_depth': rel_root_depth, 'hand_type': hand_type} + meta_info = {'joint_valid': joint_valid, 'root_valid': root_valid, 'hand_type_valid': hand_type_valid, + 'inv_trans': inv_trans, 'capture': int(data['capture']), 'cam': int(data['cam']), 'frame': int(data['frame'])} + return inputs, targets, meta_info + + + def evaluate(self, preds): + print() + print('Evaluation start...') + + gts = self.datalist + preds_joint_coord, inv_trans, joint_valid_used = preds['joint_coord'], preds['inv_trans'], preds['joint_valid'] + assert len(gts) == len(preds_joint_coord) + sample_num = len(gts) + + mpjpe_sh = [[] for _ in range(self.joint_num*2)] + mpjpe_ih = [[] for _ in range(self.joint_num*2)] + mpjpe_sh_2d = [[] for _ in range(self.joint_num*2)] + mpjpe_sh_3d = [[] for _ in range(self.joint_num*2)] + mpjpe_ih_2d = [[] for _ in range(self.joint_num*2)] + mpjpe_ih_3d = [[] for _ in range(self.joint_num*2)] + tot_err = [] + mpjpe_dict = {} + + + mrrpe = [] + acc_hand_cls = 0; hand_cls_cnt = 0; + for n in tqdm(range(sample_num),ncols=150): + vis = False + mpjpe_per_data_list = [] + mpjpe_per_data = 0 + + data = gts[n] + bbox, cam_param, joint, gt_hand_type, hand_type_valid = data['bbox'], data['cam_param'], data['joint'], data['hand_type'], data['hand_type_valid'] + hand_type = data['hand_type'] + + focal = cam_param['focal'] + princpt = cam_param['princpt'] + gt_joint_coord = joint['cam_coord'] + gt_joint_img = joint['img_coord'] + + ## use original joint_valid param. 
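+            # The steps below map each prediction back to camera space before scoring:
+            # x/y are rescaled from the output heatmap resolution to the network input and
+            # undone through inv_trans, depth is un-normalized with bbox_3d_size and offset
+            # by the per-hand absolute root depth, the result is back-projected with
+            # pixel2cam, and both hands are root-aligned before the per-joint MPJPE sums.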
+ joint_valid = joint['valid'] + # joint_valid = joint_valid_used[n] + + # restore xy coordinates to original image space + pred_joint_coord_img = preds_joint_coord[n].copy() + pred_joint_coord_img[:,0] = pred_joint_coord_img[:,0]/cfg.output_hm_shape[2]*cfg.input_img_shape[1] + pred_joint_coord_img[:,1] = pred_joint_coord_img[:,1]/cfg.output_hm_shape[1]*cfg.input_img_shape[0] + for j in range(self.joint_num*2): + pred_joint_coord_img[j,:2] = trans_point2d(pred_joint_coord_img[j,:2],inv_trans[n]) + + # restore depth to original camera space + pred_joint_coord_img[:,2] = (pred_joint_coord_img[:,2]/cfg.output_hm_shape[0] * 2 - 1) * (cfg.bbox_3d_size/2) + + # add root joint depth + pred_joint_coord_img[self.joint_type['right'],2] += data['abs_depth']['right'] + pred_joint_coord_img[self.joint_type['left'],2] += data['abs_depth']['left'] + + # back project to camera coordinate system + pred_joint_coord_cam = pixel2cam(pred_joint_coord_img, focal, princpt) + + # root joint alignment + for h in ('right', 'left'): + pred_joint_coord_cam[self.joint_type[h]] = pred_joint_coord_cam[self.joint_type[h]] - pred_joint_coord_cam[self.root_joint_idx[h],None,:] + gt_joint_coord[self.joint_type[h]] = gt_joint_coord[self.joint_type[h]] - gt_joint_coord[self.root_joint_idx[h],None,:] + + + # mpjpe + ## xyz mpjpe + for j in range(self.joint_num*2): + if joint_valid[j]: ## 在这里,限制了只加载valid的坐标值 + if gt_hand_type == 'right' or gt_hand_type == 'left': + mpjpe_sh[j].append(np.sqrt(np.sum((pred_joint_coord_cam[j] - gt_joint_coord[j])**2))) + mpjpe_per_data_list.append(np.sqrt(np.sum((pred_joint_coord_cam[j] - gt_joint_coord[j])**2))) + # continue + else: + mpjpe_ih[j].append(np.sqrt(np.sum((pred_joint_coord_cam[j] - gt_joint_coord[j])**2))) + mpjpe_per_data_list.append(np.sqrt(np.sum((pred_joint_coord_cam[j] - gt_joint_coord[j])**2))) + + + ## xy mpjpe + for j in range(self.joint_num*2): + if joint_valid[j]: + if gt_hand_type == 'right' or gt_hand_type == 'left': + mpjpe_sh_2d[j].append(np.sqrt(np.sum((pred_joint_coord_cam[j,:2] - gt_joint_coord[j,:2])**2))) + # continue + else: + mpjpe_ih_2d[j].append(np.sqrt(np.sum((pred_joint_coord_cam[j,:2] - gt_joint_coord[j,:2])**2))) + ## depth mpjpe + for j in range(self.joint_num*2): + if joint_valid[j]: + if gt_hand_type == 'right' or gt_hand_type == 'left': + mpjpe_sh_3d[j].append(np.sqrt(np.sum((pred_joint_coord_cam[j,2] - gt_joint_coord[j,2])**2))) + # continue + else: + mpjpe_ih_3d[j].append(np.sqrt(np.sum((pred_joint_coord_cam[j,2] - gt_joint_coord[j,2])**2))) + + vis_2d = False + if vis_2d: + img_path = data['img_path'] + cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) + _img = cvimg[:,:,::-1].transpose(2,0,1) + vis_kps = pred_joint_coord_img.copy() + vis_kps_gt = gt_joint_img.copy() + vis_valid = joint_valid.copy() + capture = str(data['capture']) + cam = str(data['cam']) + frame = str(data['frame']) + filename = 'out_' + str(n) + '_' + gt_hand_type + '.jpg' + vis_keypoints(_img, vis_kps, vis_kps_gt, bbox, vis_valid, self.skeleton, filename) + print('vis 2d over') + + + vis_3d = False + if vis_3d: + filename = 'out_' + str(n) + '_3d.jpg' + vis_3d_cam = pred_joint_coord_cam.copy() + vis_3d_cam_left = pred_joint_coord_cam[self.joint_type['left']].copy() + vis_3d_cam_left[:,2] = pred_joint_coord_cam[self.joint_type['left'],2] + vis_3d_cam_right = pred_joint_coord_cam[self.joint_type['right']].copy() + vis_3d_cam_right[:,2] = pred_joint_coord_cam[self.joint_type['right'],2] + vis_3d = np.concatenate((vis_3d_cam_left, vis_3d_cam_right), axis= 
0) + vis_3d_keypoints(vis_3d, joint_valid, self.skeleton, filename) + print('vis 3d over') + + + if hand_cls_cnt > 0: + handness_accuracy = acc_hand_cls / hand_cls_cnt + print('Handedness accuracy: ' + str(handness_accuracy)) + if len(mrrpe) > 0: + mrrpe_num = sum(mrrpe)/len(mrrpe) + print('MRRPE: ' + str(mrrpe_num)) + print() + + + + if self.use_inter_hand_dataset is True and self.use_single_hand_dataset is True: + print('..................MPJPE FOR TOTAL HAND..................') + eval_summary = 'MPJPE for each joint: \n' + for j in range(self.joint_num*2): + tot_err_j = np.mean(np.concatenate((np.stack(mpjpe_sh[j]), np.stack(mpjpe_ih[j])))) + joint_name = self.skeleton[j]['name'] + eval_summary += (joint_name + ': %.2f, ' % tot_err_j) + tot_err.append(tot_err_j) + print(eval_summary) + tot_err_mean = np.mean(tot_err) + print('MPJPE for all hand sequences: %.2f' % (tot_err_mean)) + mpjpe_dict['total'] = tot_err_mean + print() + + if self.use_single_hand_dataset is True: + print('..................MPJPE FOR SINGLE HAND..................') + ## xyz + eval_summary = 'MPJPE for each joint: \n' + for j in range(self.joint_num*2): + mpjpe_sh[j] = np.mean(np.stack(mpjpe_sh[j])) + joint_name = self.skeleton[j]['name'] + eval_summary += (joint_name + ': %.2f, ' % mpjpe_sh[j]) + print(eval_summary) + mpjpe_sh_mean = np.mean(mpjpe_sh) + print('MPJPE for single hand sequences: %.2f' % (mpjpe_sh_mean)) + mpjpe_dict['single_hand_total'] = mpjpe_sh_mean + print() + + ## xy + eval_summary_2d = 'MPJPE for each joint 2d: \n' + for j in range(self.joint_num*2): + mpjpe_sh_2d[j] = np.mean(np.stack(mpjpe_sh_2d[j])) + joint_name = self.skeleton[j]['name'] + eval_summary_2d += (joint_name + ': %.2f, ' % mpjpe_sh_2d[j]) + print(eval_summary_2d) + mpjpe_sh_2d_mean = np.mean(mpjpe_sh_2d) + print('MPJPE for single hand sequences 2d: %.2f' % (mpjpe_sh_2d_mean)) + mpjpe_dict['single_hand_2d'] = mpjpe_sh_2d_mean + print() + + ## z + eval_summary_3d = 'MPJPE for each joint depth: \n' + for j in range(self.joint_num*2): + mpjpe_sh_3d[j] = np.mean(np.stack(mpjpe_sh_3d[j])) + joint_name = self.skeleton[j]['name'] + eval_summary_3d += (joint_name + ': %.2f, ' % mpjpe_sh_3d[j]) + print(eval_summary_3d) + mpjpe_sh_3d_mean = np.mean(mpjpe_sh_3d) + print('MPJPE for single hand sequences 3d: %.2f' % (mpjpe_sh_3d_mean)) + mpjpe_dict['single_hand_depth'] = mpjpe_sh_3d_mean + print() + + + if self.use_inter_hand_dataset is True: + print('..................MPJPE FOR INTER HAND..................') + ## xyz + eval_summary = 'MPJPE for each joint: \n' + for j in range(self.joint_num*2): + mpjpe_ih[j] = np.mean(np.stack(mpjpe_ih[j])) + joint_name = self.skeleton[j]['name'] + eval_summary += (joint_name + ': %.2f, ' % mpjpe_ih[j]) + print(eval_summary) + mpjpe_ih_mean = np.mean(mpjpe_ih) + print('MPJPE for interacting hand sequences: %.2f' % (mpjpe_ih_mean)) + mpjpe_dict['inter_hand_total'] = mpjpe_ih_mean + print() + + ## xy + eval_summary_2d = 'MPJPE for each joint 2d: \n' + for j in range(self.joint_num*2): + mpjpe_ih_2d[j] = np.mean(np.stack(mpjpe_ih_2d[j])) + joint_name = self.skeleton[j]['name'] + eval_summary_2d += (joint_name + ': %.2f, ' % mpjpe_ih_2d[j]) + print(eval_summary_2d) + mpjpe_ih_2d_mean = np.mean(mpjpe_ih_2d) + print('MPJPE for interacting hand sequences 2d: %.2f' % (mpjpe_ih_2d_mean)) + mpjpe_dict['inter_hand_2d'] = mpjpe_ih_2d_mean + print() + + ## z + eval_summary_3d = 'MPJPE for each joint depth: \n' + for j in range(self.joint_num*2): + mpjpe_ih_3d[j] = np.mean(np.stack(mpjpe_ih_3d[j])) + joint_name = 
self.skeleton[j]['name'] + eval_summary_3d += (joint_name + ': %.2f, ' % mpjpe_ih_3d[j]) + print(eval_summary_3d) + mpjpe_ih_3d_mean = np.mean(mpjpe_ih_3d) + print('MPJPE for interacting hand sequences 3d: %.2f' % (mpjpe_ih_3d_mean)) + mpjpe_dict['inter_hand_depth'] = mpjpe_ih_3d_mean + print() + + + if hand_cls_cnt > 0 and len(mrrpe) > 0: + return mpjpe_dict, handness_accuracy, mrrpe_num + else: + return mpjpe_dict, None, None + \ No newline at end of file diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..b3807e2 --- /dev/null +++ b/logger.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +import os + +OK = '\033[92m' +WARNING = '\033[93m' +FAIL = '\033[91m' +END = '\033[0m' + +PINK = '\033[95m' +BLUE = '\033[94m' +GREEN = OK +RED = FAIL +WHITE = END +YELLOW = WARNING + +class colorlogger(): + def __init__(self, log_dir, log_name='train_logs.txt'): + # set log + self._logger = logging.getLogger(log_name) + self._logger.setLevel(logging.INFO) + log_file = os.path.join(log_dir, log_name) + if not os.path.exists(log_dir): + os.makedirs(log_dir) + file_log = logging.FileHandler(log_file, mode='a') + file_log.setLevel(logging.INFO) + console_log = logging.StreamHandler() + console_log.setLevel(logging.INFO) + formatter = logging.Formatter( + "{}%(asctime)s{} %(message)s".format(GREEN, END), + "%m-%d %H:%M:%S") + file_log.setFormatter(formatter) + console_log.setFormatter(formatter) + self._logger.addHandler(file_log) + self._logger.addHandler(console_log) + + def debug(self, msg): + self._logger.debug(str(msg)) + + def info(self, msg): + self._logger.info(str(msg)) + + def warning(self, msg): + self._logger.warning(WARNING + 'WRN: ' + str(msg) + END) + + def critical(self, msg): + self._logger.critical(RED + 'CRI: ' + str(msg) + END) + + def error(self, msg): + self._logger.error(RED + 'ERR: ' + str(msg) + END) + diff --git a/model.py b/model.py new file mode 100644 index 0000000..5284aff --- /dev/null +++ b/model.py @@ -0,0 +1,252 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import numpy as np +from config import cfg +from anchor import * + +from nets.layer import MLP +from nets.position_encoding import build_position_encoding + +# use dab-deformable-detr +from dab_deformable_detr.deformable_transformer import build_deforamble_transformer +from dab_deformable_detr.backbone import build_backbone +from utils.miscdetr import NestedTensor +import copy + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + +class A2J_model(nn.Module): + def __init__(self, backbone_net, position_embedding, transformer, num_classes, num_queries, num_feature_levels, + with_box_refine=True, two_stage=False, use_dab=True, + num_patterns=0, anchor_refpoints_xy=True, fix_anchor=False, is_3D=True, use_lvl_weights=False): + """ Initializes the model. + Parameters: + backbone_net: torch module of the backbone to be used. + transformer: torch module of the transformer architecture. 
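+            position_embedding: positional encoding module (sine or learned, see
+                position_encoding.py) computed on the multi-scale features passed to the transformer.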
+ num_classes: number of object classes, given 42 + num_queries: number of object queries, given 256*3 + num_feature_levels: number of feature layers used form backbone, default = 4 + with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR, default is False + use_dab: using dab-deformable-detr + num_patterns: number of pattern embeddings + anchor_refpoints_xy: init the x,y of anchor boxes to A2J anchors and freeze them. + fix_anchor: if fix the reference points as the initial anchor points to stop renew them + is_3D: if the model regresses 3D coords of the keypoints + """ + super(A2J_model, self).__init__() + self.backbone = backbone_net + self.position_embedding = position_embedding + self.transformer = transformer + self.num_classes = num_classes # default = 42 + self.num_queries = num_queries # default = 768 + self.num_feature_levels = num_feature_levels # default = 4 + self.with_box_refine = with_box_refine + self.two_stage = two_stage + self.use_dab = use_dab + self.num_patterns = num_patterns + self.anchor_refpoints_xy = anchor_refpoints_xy + self.fix_anchor = fix_anchor + self.is_3D = is_3D + self.use_lvl_weights = use_lvl_weights + self.kernel_size = cfg.kernel_size + + hidden_dim = transformer.d_model # =cfg.hidden_dim, default = 256 + self.bbox_embed_anchor = MLP(hidden_dim, hidden_dim, 2, 3) + if self.is_3D: + self.bbox_embed_keypoints = MLP(hidden_dim, hidden_dim, self.num_classes *3, 3) ## 3D coord + else: + assert self.is_3D is False + self.bbox_embed_keypoints = MLP(hidden_dim, hidden_dim, self.num_classes *2, 3) ## only xy-coord + self.anchor_weights = MLP(hidden_dim, hidden_dim, self.num_classes *1, 3) + + if not two_stage: + if not use_dab: + self.query_embed = nn.Embedding(num_queries, hidden_dim*2) + else: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.refpoint_embed = nn.Embedding(num_queries, 3) + + if anchor_refpoints_xy: + self.anchors = generate_all_anchors_3d() + self.anchors = torch.from_numpy(self.anchors).cuda().float() + self.refpoint_embed.weight.data = self.anchors + self.refpoint_embed.weight.data.requires_grad = False + + + if num_feature_levels > 1: + num_backbone_outs = len(self.backbone.strides) #8,16,32 + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = self.backbone.num_channels[_] # [512, 1024, 2048] + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, hidden_dim), + )) + in_channels = hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(self.backbone.num_channels[0], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )]) + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers + if with_box_refine: + self.bbox_embed_anchor = _get_clones(self.bbox_embed_anchor, num_pred) + self.transformer.decoder.bbox_embed = self.bbox_embed_anchor + else: + nn.init.constant_(self.bbox_embed_anchor.layers[-1].bias.data[2:], -2.0) + self.bbox_embed_anchor = nn.ModuleList([self.bbox_embed_anchor for _ 
in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + for box_embed in self.bbox_embed_anchor: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + + # for final output + nn.init.constant_(self.bbox_embed_anchor[0].layers[-1].bias.data[2:], -2.0) + self.bbox_embed_keypoints = _get_clones(self.bbox_embed_keypoints, num_pred) + nn.init.constant_(self.bbox_embed_keypoints[0].layers[-1].bias.data[2:], -2.0) + self.anchor_weights = _get_clones(self.anchor_weights, num_pred) + nn.init.constant_(self.anchor_weights[0].layers[-1].bias.data[2:], -2.0) + + self.generate_keypoints_coord_new = generate_keypoints_coord_new(self.num_classes, is_3D = self.is_3D) + + + + # def forward(self, x): + def forward(self, inputs, targets, meta_info, mode): + input_img = inputs['img'] + input_mask = inputs['mask'] + batch_size = input_img.shape[0] + samples = NestedTensor(input_img,input_mask.squeeze(1)) + + ## get pyramid features + features, pos = self.backbone(samples) + srcs = [] + masks = [] + + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + tgt_embed = self.tgt_embed.weight + refanchor = self.refpoint_embed.weight + + ## Convert refanchor to [0,1] range + refanchor = refanchor / cfg.output_hm_shape_all + query_embeds = torch.cat((tgt_embed, refanchor), dim=1) + + ## Transformer module. Enhance features. + hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, query_embeds) + + outputs_coords = [] + outputs_weights = [] + references = [] + + ## Predict offset and weights for each layer. + ## Total 6 layers, which is the same as enc/dec layers. 
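+        ## Shape walk-through of the A2J-style aggregation below (with the default config of
+        ## 768 anchor queries, 42 joints and hidden_dim 256):
+        ##   hs[lvl]                        : (B, 768, 256)  enhanced anchor features
+        ##   anchor_weights[lvl](hs[lvl])   : (B, 768, 42)   one response weight per anchor and joint
+        ##   bbox_embed_keypoints[lvl](...) : (B, 768, 42*3) per-anchor offsets to every joint
+        ## generate_keypoints_coord_new then softmaxes the weights over the anchor axis and sums
+        ## weights * (anchor position + offset), so every joint is estimated as an anchor-weighted
+        ## ensemble of local predictions, as in the original A2J formulation.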
+ for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference.squeeze(0).expand(batch_size,-1 ,-1) + else: + reference = inter_references[lvl - 1] + outputs_weight = self.anchor_weights[lvl](hs[lvl]) + tmp = self.bbox_embed_keypoints[lvl](hs[lvl]) + assert reference.shape[-1] == 3 + + ## convert result to [0,256] range, same to size of output img + reference = reference * cfg.output_hm_shape_all + outputs_coord = tmp * cfg.output_hm_shape_all + + outputs_coords.append(outputs_coord) + outputs_weights.append(outputs_weight) + references.append(reference) + + total_outputs_coord = torch.stack(outputs_coords) ## A2J-offsets + total_outputs_weights = torch.stack(outputs_weights) ## A2J-weights + total_references = torch.stack(references) ## A2J-anchors + + ## generate final coords + keypoints_coord, anchor = self.generate_keypoints_coord_new(total_outputs_coord, total_outputs_weights, total_references) + + if mode == 'test': + ## use the result of last layer as the final result + pred_keypoints = keypoints_coord[-1] + out = {} + out['joint_coord'] =pred_keypoints + if 'inv_trans' in meta_info: + out['inv_trans'] = meta_info['inv_trans'] + if 'joint_coord' in targets: + out['target_joint'] = targets['joint_coord'] + if 'joint_valid' in meta_info: + out['joint_valid'] = meta_info['joint_valid'] + if 'hand_type_valid' in meta_info: + out['hand_type_valid'] = meta_info['hand_type_valid'] + return out + + + +def get_model(mode, joint_num): + backbone_net = build_backbone(cfg) + transformer = build_deforamble_transformer(cfg) + position_embedding = build_position_encoding(cfg) + model = A2J_model(backbone_net, + position_embedding, + transformer, + num_classes = joint_num * 2, + num_queries = cfg.num_queries, + num_feature_levels = cfg.num_feature_levels, + two_stage=cfg.two_stage, + use_dab=True, + num_patterns=cfg.num_patterns, + anchor_refpoints_xy=cfg.anchor_refpoints_xy, + fix_anchor = cfg.fix_anchor, + is_3D=cfg.is_3D, + use_lvl_weights=cfg.use_lvl_weights) + + ## Statistical Model Size + print('BackboneNet No. of Params = %d M'%(sum(p.numel() for p in backbone_net.parameters() if p.requires_grad)/1e6)) + print('Transformer No. of Params = %d M'%(sum(p.numel() for p in transformer.parameters() if p.requires_grad)/1e6)) + print('Total No. of Params = %d M' % (sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6)) + return model diff --git a/nets/layer.py b/nets/layer.py new file mode 100644 index 0000000..68a4784 --- /dev/null +++ b/nets/layer.py @@ -0,0 +1,183 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from torch.nn.modules.module import Module +import math +from config import cfg + +def make_linear_layers(feat_dims, relu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append(nn.Linear(feat_dims[i], feat_dims[i+1])) + + # Do not use ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and relu_final): + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + +def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + nn.Conv2d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=kernel, + stride=stride, + padding=padding + )) + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + layers.append(nn.BatchNorm2d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + +def make_deconv_layers(feat_dims, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + nn.ConvTranspose2d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=4, + stride=2, + padding=1, + output_padding=0, + bias=False)) + + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + layers.append(nn.BatchNorm2d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode): + super(Interpolate, self).__init__() + self.interp = F.interpolate + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False) + return x + +def make_upsample_layers(feat_dims, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + Interpolate(2, 'bilinear')) + layers.append( + nn.Conv2d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=3, + stride=1, + padding=1 + )) + + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + layers.append(nn.BatchNorm2d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + +class ResBlock(nn.Module): + def __init__(self, in_feat, out_feat): + super(ResBlock, self).__init__() + self.in_feat = in_feat + self.out_feat = out_feat + + self.conv = make_conv_layers([in_feat, out_feat, out_feat], bnrelu_final=False) + self.bn = nn.BatchNorm2d(out_feat) + if self.in_feat != self.out_feat: + self.shortcut_conv = nn.Conv2d(in_feat,out_feat,kernel_size=1,stride=1,padding=0) + self.shortcut_bn = nn.BatchNorm2d(out_feat) + + def forward(self, input): + x = self.bn(self.conv(input)) + if self.in_feat != self.out_feat: + x = F.relu(x + self.shortcut_bn(self.shortcut_conv(input))) + else: + x = F.relu(x + input) + return x + +def make_conv3d_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + nn.Conv3d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=kernel, + stride=stride, + padding=padding + )) + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + 
layers.append(nn.BatchNorm3d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + +def make_deconv3d_layers(feat_dims, bnrelu_final=True): + layers = [] + for i in range(len(feat_dims)-1): + layers.append( + nn.ConvTranspose3d( + in_channels=feat_dims[i], + out_channels=feat_dims[i+1], + kernel_size=4, + stride=2, + padding=1, + output_padding=0, + bias=False)) + + # Do not use BN and ReLU for final estimation + if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final): + layers.append(nn.BatchNorm3d(feat_dims[i+1])) + layers.append(nn.ReLU(inplace=True)) + + return nn.Sequential(*layers) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, is_activation_last=False): + super().__init__() + self.num_layers = num_layers + self.is_activation_last = is_activation_last + if not isinstance(hidden_dim, list): + h = [hidden_dim] * (num_layers - 1) + else: + assert isinstance(hidden_dim, list), 'hidden_dim arg should be list or a number' + assert len(hidden_dim) == num_layers-1, 'len(hidden_dim) != num_layers-1' + h = hidden_dim + + # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + all_dims = [input_dim] + h + [output_dim] + self.layers = nn.ModuleList(nn.Linear(all_dims[i], all_dims[i+1]) for i in range(num_layers)) + + def forward(self, x): + for i, layer in enumerate(self.layers[:-1]): + x = F.relu(layer(x)) + + if self.is_activation_last: + x = F.relu(self.layers[-1](x)) + else: + x = self.layers[-1](x) + + return x \ No newline at end of file diff --git a/nets/position_encoding.py b/nets/position_encoding.py new file mode 100644 index 0000000..81ff3e5 --- /dev/null +++ b/nets/position_encoding.py @@ -0,0 +1,158 @@ +# Copyright (c) 2020 Graz University of Technology All rights reserved. + +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn +import matplotlib.pyplot as plt +from utils.misc import NestedTensor +from nets.layer import MLP +from nets.layer import make_linear_layers, make_conv_layers, make_deconv_layers + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
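+    Roughly: each location's (y, x) index is obtained via cumulative sums over the un-masked
+    region, optionally normalized so the largest index maps to `scale` (2*pi by default), and
+    each index is encoded channel-pair-wise as sin(pos / T^(2i/d)) and cos(pos / T^(2i/d))
+    with temperature T (10000 by default); half of the channels encode y, the other half x.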
+ """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, img, mask): + assert mask is not None + mask = (mask.squeeze(1)<0) + + not_mask = ~ mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=img.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + # pos_x = x_embed[:, :, :, None] / dim_t + # pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.div(x_embed[:, :, :, None], dim_t) + pos_y = torch.div(y_embed[:, :, :, None], dim_t) + + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3) + pos = pos.permute(0, 3, 1, 2) # N x hidden_dim x H x W + + return pos ,mask + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, img, mask): + x = img + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + +class PositionEmbeddingConvLearned(nn.Module): + def __init__(self, num_pos_feats=256): + super(PositionEmbeddingConvLearned, self).__init__() + self.input_size = 8 + self.num_emb_layers = [32, 64, 128, 128, num_pos_feats] # 8, 16, 32, 64, 128 + self.embed = nn.Embedding(self.input_size*self.input_size, self.num_emb_layers[0]) + + self.deconv_layers = [] + for i in range(len(self.num_emb_layers)-1): + if i == len(self.num_emb_layers)-1: + self.deconv_layers.append(make_deconv_layers([self.num_emb_layers[i], self.num_emb_layers[i + 1]], bnrelu_final=False).to('cuda')) + else: + self.deconv_layers.append(make_deconv_layers([self.num_emb_layers[i], self.num_emb_layers[i+1]]).to('cuda')) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.embed.weight) + + def forward(self, img, mask): + input = self.embed.weight.view(self.input_size, self.input_size, self.num_emb_layers[0]).permute(2,0,1).unsqueeze(0).cuda() + for i in range(len(self.deconv_layers)): + input = self.deconv_layers[i](input) + input = input.repeat([img.shape[0],1,1,1]) + return input + + +class PositionEmbeddingLinearLearned(nn.Module): + def __init__(self, num_pos_feats=256): + super(PositionEmbeddingLinearLearned, self).__init__() + self.linear = MLP(input_dim=2, hidden_dim=[16, 32, 64, 128], output_dim=num_pos_feats, 
num_layers=5) + + def forward(self, img, mask): + xx, yy = torch.meshgrid(torch.arange(img.shape[3]), torch.arange(img.shape[2])) + pixel_locs = torch.stack([yy, xx], dim=2).to(torch.float).to(img.device) # 128 x 128 x 2 + pos = self.linear(pixel_locs.view(-1, 2)) # 128*128 x 256 + pos = pos.view(img.shape[2], img.shape[3], pos.shape[-1]).permute(2,0,1) + pos = pos.unsqueeze(0).repeat([img.shape[0],1,1,1]) + return pos + + +class PositionEmbeddingSimpleCat(nn.Module): + def __init__(self, num_pos_feats=256): + super(PositionEmbeddingSimpleCat, self).__init__() + + def forward(self, img, mask): + xx, yy = torch.meshgrid(torch.arange(img.shape[3]), torch.arange(img.shape[2])) + pos = torch.stack([yy, xx], dim=2).to(torch.float).to(img.device) # 128 x 128 x 2 + pos = pos.permute(2,0,1).unsqueeze(0).repeat([img.shape[0],1,1,1]) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + elif args.position_embedding in ('v4', 'convLearned'): + position_embedding = PositionEmbeddingConvLearned(args.hidden_dim) + elif args.position_embedding in ('v5', 'linearLearned'): + position_embedding = PositionEmbeddingLinearLearned(args.hidden_dim) + elif args.position_embedding in ('v6', 'simpleCat'): + position_embedding = PositionEmbeddingSine(16, normalize=True) + # position_embedding = PositionEmbeddingSimpleCat(N_steps) + + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/test.py b/test.py new file mode 100644 index 0000000..692a6ba --- /dev/null +++ b/test.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import argparse +from tqdm import tqdm +import numpy as np +import cv2 +from config import cfg +import torch +from base import Tester +from utils.vis import vis_keypoints +import torch.backends.cudnn as cudnn +from utils.transforms import flip + +import time + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--gpu', type=str, default='6,7', dest='gpu_ids') + parser.add_argument('--test_set', type=str, default='test', dest='test_set') + args = parser.parse_args() + + if not args.gpu_ids: + assert 0, "Please set propoer gpu ids" + + if '-' in args.gpu_ids: + gpus = args.gpu_ids.split('-') + gpus[0] = int(gpus[0]) + gpus[1] = int(gpus[1]) + 1 + args.gpu_ids = ','.join(map(lambda x: str(x), list(range(*gpus)))) + + return args + +def test(): + args = parse_args() + cfg.set_args(args.gpu_ids) + cudnn.benchmark = True + + if cfg.dataset == 'InterHand2.6M': + assert args.test_set, 'Test set is required. 
Select one of test/val' + else: + args.test_set = 'test' + + tester = Tester() + tester._make_batch_generator(args.test_set) + tester._make_model() + + preds = {'joint_coord': [], 'inv_trans': [], 'joint_valid': [] } + + timer = [] + + + with torch.no_grad(): + for itr, (inputs, targets, meta_info) in enumerate(tqdm(tester.batch_generator,ncols=150)): + + # forward + start = time.time() + out = tester.model(inputs, targets, meta_info, 'test') + end = time.time() + + joint_coord_out = out['joint_coord'].cpu().numpy() + inv_trans = out['inv_trans'].cpu().numpy() + joint_vaild = out['joint_valid'].cpu().numpy() + + preds['joint_coord'].append(joint_coord_out) + preds['inv_trans'].append(inv_trans) + preds['joint_valid'].append(joint_vaild) + + timer.append(end-start) + + + # evaluate + preds = {k: np.concatenate(v) for k,v in preds.items()} + + mpjpe_dict, hand_accuracy, mrrpe = tester._evaluate(preds) + print(mpjpe_dict) + print('time per batch is',np.mean(timer)) + +if __name__ == "__main__": + test() diff --git a/timer.py b/timer.py new file mode 100644 index 0000000..e1e20d7 --- /dev/null +++ b/timer.py @@ -0,0 +1,68 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- +# -----------------------LICENSE-------------------------- +# Fast R-CNN +# +# Copyright (c) Microsoft Corporation +# +# All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +# OTHER DEALINGS IN THE SOFTWARE. +# -------------------------------------------------------- +# +# The code is from the publicly available implementation of Fast R-CNN +# https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/timer.py + +import time + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. 
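+        # warm-up counter: the first 10 toc() calls are excluded from the running average
+        # (see toc below), so that one-off costs such as CUDA context creation or cudnn
+        # autotuning do not skew the average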
+ self.warm_up = 0 + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + if self.warm_up < 10: + self.warm_up += 1 + return self.diff + else: + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + + if average: + return self.average_time + else: + return self.diff diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/box_ops.py b/utils/box_ops.py new file mode 100644 index 0000000..ca29592 --- /dev/null +++ b/utils/box_ops.py @@ -0,0 +1,96 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
+ + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/utils/dir.py b/utils/dir.py new file mode 100644 index 0000000..6c1336c --- /dev/null +++ b/utils/dir.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import sys + +def make_folder(folder_name): + if not os.path.exists(folder_name): + os.makedirs(folder_name) + +def add_pypath(path): + if path not in sys.path: + sys.path.insert(0, path) + diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..b71138e --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,431 @@ +# Copyright (c) 2020 Graz University of Technology All rights reserved. + +import torch +import torch.distributed as dist +from torch import Tensor +from typing import Optional, List +import numpy as np +from config import cfg +from scipy.optimize import linear_sum_assignment + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = 
nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + +def hungarian_match_2djoints(joint_loc_gt_np, joint_loc_pred_np, joint_valid_np, mask, joint_class_pred): + mid_dist_th = 2 + # get associations based on hungarian algo. + gt_inds_batch = [] + row_inds_batch = [] + for i in range(joint_loc_pred_np.shape[0]): + # for invalid joints, make sure there locations are somewhere far off, so that any asso. with them will be invalid + joint_loc_gt_np[i][np.logical_not(joint_valid_np[i])] = np.array( + [cfg.input_img_shape[0], cfg.input_img_shape[0]]) * 2 + + # create the cost matrix. any cost > mid_dist_th is clipped. 
This helps in scenarios where cost matrix is tall + dist = np.linalg.norm(np.expand_dims(joint_loc_pred_np[i], 1) - np.expand_dims(joint_loc_gt_np[i], 0), + axis=2) # max_num_peaks x 21 + dist1 = dist[mask[i]] # remove the invalid guys + dist1[dist1 > mid_dist_th] = mid_dist_th + + # invoke the hungry hungarian + indices = linear_sum_assignment(dist1) + row_ind = indices[0] + asso_ind = indices[1] + + # 0 - invalid class, rest all classes follow the same order as in the joint_gt + gt_inds = asso_ind + 1 + + # if the any associations have a distance > mid_dist_th, then its not right + for ii in range(row_ind.shape[0]): + if dist1[row_ind[ii], asso_ind[ii]] >= mid_dist_th: + gt_inds[ii] = 0 + + # when cost matrix is tall, assign the false postives to 0 index + if row_ind.shape[0] < np.sum(mask[i]): + false_pos_row_inds = np.setdiff1d(np.arange(0, np.sum(mask[i])), row_ind) + assert false_pos_row_inds.shape[0] == (np.sum(mask[i]) - row_ind.shape[0]) + row_ind = np.concatenate([row_ind, false_pos_row_inds], axis=0) + gt_inds = np.concatenate([gt_inds, np.zeros((false_pos_row_inds.shape[0],))], axis=0) + + gt_inds_batch.append(gt_inds) + row_inds_batch.append(row_ind) + asso_inds_batch_list = gt_inds_batch + row_inds_batch_list = row_inds_batch + gt_inds_batch = np.concatenate(gt_inds_batch, axis=0).astype(np.int) + + gt_inds_batch = np.tile(np.expand_dims(gt_inds_batch, 0), [joint_class_pred.shape[0], 1]) # 6 x M + gt_inds_batch = np.reshape(gt_inds_batch, [-1]) # 6*M + + return gt_inds_batch, row_inds_batch_list, asso_inds_batch_list + + +def nearest_match_2djoints(joint_loc_gt_np, joint_loc_pred_np, joint_valid_np, mask, joint_class_pred, joints_3d_gt_np, + peak_joints_map_batch_np, target_obj_kps_heatmap_np, obj_kps_coord_gt_np, obj_kps_3d_gt_np): + mid_dist_th = 3 + gt_inds_batch = [] + row_inds_batch = [] + bs = joint_loc_gt_np.shape[0] + + for i in range(bs): + # for invalid joints, make sure there locations are somewhere far off, so that any asso. 
with them will be invalid + joint_loc_gt_np[i][np.logical_not(joint_valid_np[i])] = np.array( + [cfg.input_img_shape[0], cfg.input_img_shape[0]]) * 2 + + gt_inds = np.zeros((np.sum(mask[i]))) + for j in np.arange(0, cfg.max_num_peaks): + if mask[i,j] == 0: + continue + + curr_joint_loc_pred = joint_loc_pred_np[i, j] + if cfg.has_object: + if peak_joints_map_batch_np[i,j] == cfg.obj_cls_index: + + if target_obj_kps_heatmap_np[i, int(curr_joint_loc_pred[1]), int(curr_joint_loc_pred[0])] > 100: + gt_inds[j] = cfg.obj_cls_index + else: + gt_inds[j] = 0 + continue + + + dist = np.linalg.norm(curr_joint_loc_pred - joint_loc_gt_np[i], axis=-1) + closest_pts_mask = dist <= mid_dist_th + if not np.any(closest_pts_mask): + continue + foreground_pt_ind = np.argmin(joints_3d_gt_np[i,:,2][closest_pts_mask]) + gt_inds[j] = np.where(closest_pts_mask)[0][foreground_pt_ind]+1 + + gt_inds_batch.append(gt_inds) + row_inds_batch.append(np.arange(0, np.sum(mask[i]))) + + asso_inds_batch_list = gt_inds_batch + row_inds_batch_list = row_inds_batch + gt_inds_batch = np.concatenate(gt_inds_batch, axis=0).astype(np.int) + + gt_inds_batch = np.tile(np.expand_dims(gt_inds_batch, 0), [joint_class_pred.shape[0], 1]) # 6 x M + gt_inds_batch = np.reshape(gt_inds_batch, [-1]) # 6*M + + return gt_inds_batch, row_inds_batch_list, asso_inds_batch_list + +def binary(x, bits): + mask = 2**torch.arange(bits).to(x.device, x.dtype) + return x.unsqueeze(-1).bitwise_and(mask).ne(0).to(torch.float32) + +def get_tgt_mask(): + if cfg.predict_type == 'angles': + tgt_mask = torch.zeros((cfg.num_queries, cfg.num_queries), dtype=torch.bool) + # global rot + tgt_mask[0, :] = True + tgt_mask[0, 0] = False + tgt_mask[cfg.num_joint_queries_per_hand, :] = True + tgt_mask[cfg.num_joint_queries_per_hand, cfg.num_joint_queries_per_hand] = False + + # fingers + for i in range(5): + # right hand + s = 3 * i + 1 + e = 3 * i + 4 + tgt_mask[s:e, :] = True + tgt_mask[s:e, s:e] = False + # left hand + s = s + cfg.num_joint_queries_per_hand + e = e + cfg.num_joint_queries_per_hand + tgt_mask[s:e, :] = True + tgt_mask[s:e, s:e] = False + + # trans and shape + tgt_mask[cfg.shape_indx, :] = True + tgt_mask[cfg.shape_indx, cfg.shape_indx] = False + + + if cfg.has_object: + # make hand queries depend on object + tgt_mask[:2 * cfg.num_joint_queries_per_hand, cfg.obj_rot_indx] = False + tgt_mask[:2 * cfg.num_joint_queries_per_hand, cfg.obj_trans_indx] = False + tgt_mask[cfg.shape_indx, cfg.obj_rot_indx] = False + tgt_mask[cfg.shape_indx, cfg.obj_trans_indx] = False + + elif cfg.predict_type == 'vectors': + tgt_mask = torch.zeros((cfg.num_queries, cfg.num_queries), dtype=torch.bool) + # fingers + for i in range(5): + # right hand + s = 4 * i + 0 + e = 4 * i + 4 + tgt_mask[s:e, :] = True + tgt_mask[s:e, s:e] = False + # left hand + s = s + cfg.num_joint_queries_per_hand + e = e + cfg.num_joint_queries_per_hand + tgt_mask[s:e, :] = True + tgt_mask[s:e, s:e] = False + # trans and shape + tgt_mask[cfg.shape_indx, :] = True + tgt_mask[cfg.shape_indx, cfg.shape_indx] = False + + + else: + raise NotImplementedError + + + return tgt_mask + +def get_src_memory_mask(peak_joints_map_batch): + src_mask_list = [] + memory_mask_list = [] + for i in range(peak_joints_map_batch.shape[0]): + mask = torch.zeros((cfg.max_num_peaks, cfg.max_num_peaks), dtype=torch.bool) + joint_locs_mask = torch.logical_and(peak_joints_map_batch[i] != cfg.obj_cls_index, peak_joints_map_batch[i] != 0) + mask[joint_locs_mask] = True + mask[joint_locs_mask, joint_locs_mask] = False + + 
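+        # under the PyTorch attn_mask convention (True = not allowed to attend), rows belonging
+        # to detected joint peaks are masked except for their diagonal entry, i.e. joint peaks
+        # may only attend to themselves while the remaining peaks attend freely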
src_mask_list.append(mask) + + memory_mask = torch.zeros((cfg.num_queries, cfg.max_num_peaks), dtype=torch.bool) + if np.sum((peak_joints_map_batch[i]==cfg.obj_cls_index).cpu().numpy()) > 0: + memory_mask[cfg.obj_rot_indx] = peak_joints_map_batch[i]!=cfg.obj_cls_index + memory_mask_list.append(memory_mask) + + src_mask = torch.stack(src_mask_list,dim=0).unsqueeze(1).repeat(1,cfg.nheads,1,1).view(-1, cfg.max_num_peaks, cfg.max_num_peaks) + memory_mask_list = torch.stack(memory_mask_list, dim=0).unsqueeze(1).repeat(1,cfg.nheads,1,1).view(-1, cfg.num_queries, cfg.max_num_peaks) + + return src_mask, memory_mask_list + +def get_root_rel_from_parent_rel_depths(dep): + joint_recon_order = [3, 2, 1, 0, + 7, 6, 5, 4, + 11, 10, 9, 8, + 15, 14, 13, 12, + 19, 18, 17, 16] + + dep_root = [] + for j in range(5): + for i in range(4): + if i == 0: + dep_root.append(dep[joint_recon_order[j*4+i]]) + else: + new_dep = dep[joint_recon_order[j*4+i]] + dep_root[-1] + dep_root.append(new_dep) + + dep_root_reorder = np.array([dep_root[i] for i in joint_recon_order]+[0]) + return dep_root_reorder + +def my_print(string, f=None): + print(string) + if f is not None: + f.write(string+'\n') + +def batch_rodrigues( + rot_vecs: Tensor, + epsilon: float = 1e-8, +) -> Tensor: + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + + batch_size = rot_vecs.shape[0] + device, dtype = rot_vecs.device, rot_vecs.dtype + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + +def get_valid_front_face_from_bary(bary, verts, faces_in): + # given barycentric coordinates for each triangle in the mesh, find which frontmost triangle the point belongs to + ''' + + :param bary: N x M x F x 3 + :param verts: N x V x 3 + :param faces_in: F x 3 + :return: + ''' + faces = np.tile(np.expand_dims(faces_in, 0), (verts.shape[0], 1, 1)) # N x F x 3 + for i in range(verts.shape[0]): + faces[i] += verts.shape[1]*i + tri_verts = verts.reshape(-1, 3)[faces.reshape(-1)].reshape(verts.shape[0], faces_in.shape[0], 3, + 3) # N x F x 3 x 3 + mean_tri_dep = torch.mean(tri_verts[:,:,:,2], dim=2) # N x F + mean_tri_dep = mean_tri_dep.unsqueeze(1).repeat(1, bary.shape[1], 1) # N x M x F + + inside_pts = torch.logical_and(bary[:, :, :, 0] >= 0, bary[:, :, :, 0] <= 1) + inside_pts = torch.logical_and(inside_pts, bary[:, :, :, 1] >= 0) + inside_pts = torch.logical_and(inside_pts, bary[:, :, :, 1] <= 1) + inside_pts = torch.logical_and(inside_pts, bary[:, :, :, 2] >= 0) + inside_pts = torch.logical_and(inside_pts, bary[:, :, :, 2] <= 1) # N x M x F + + mean_tri_dep[torch.logical_not(inside_pts)] = float('inf') # N x M x F + min_val, hit_tri_ind = torch.min(mean_tri_dep, dim=2) # N x M + valid_hit_tri_ind = min_val != float('inf') # N x M + + hit_tri_verts = torch.gather(tri_verts, 1, 
hit_tri_ind[:,:,None,None].repeat(1,1,3,3)) # N x M x 3 x 3 + hit_tri_center = torch.mean(hit_tri_verts, dim=2) # N x M x 3 + + return hit_tri_center, hit_tri_ind, valid_hit_tri_ind + +def get_mesh_contacts(contact_pos_pred, vert, faces_in, cam_param): + ''' + + :param contact_pos_pred: N x M x 2 + :param vert: N x V x 3 + :param faces_in: F x 3 + :param cam_param: N x 3 + :return: + ''' + + contact_pos_plane = torch.cat([contact_pos_pred, + torch.zeros((contact_pos_pred.shape[0], contact_pos_pred.shape[1], 1)).to(contact_pos_pred.device)], dim=2) # N x M x 3 + vert_plane = vert[:,:,:2]*cam_param.unsqueeze(1)[:,:,:1] + cam_param.unsqueeze(1)[:,:,1:] # N x V x 2 + vert_plane = torch.cat([vert_plane, torch.zeros((vert_plane.shape[0], vert_plane.shape[1], 1)).to(vert_plane.device)], dim=2) # N x V x 3 + + bary_points = get_barycentric_points_from_contained_points(contact_pos_plane, vert_plane, faces_in) # N x M x F x 3 + hit_tri_center, hit_tri_ind, valid_hit_tri_ind = get_valid_front_face_from_bary(bary_points, vert, faces_in) + + return hit_tri_center, hit_tri_ind, valid_hit_tri_ind + + + + +def get_barycentric_points_from_contained_points(points, verts, faces_in): + # give a set of points on the surface of the mesh, get their barycentric coordinates for each triangle in face + # http://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates + ''' + + :param points: N x M x 3 + :param verts: N x V x 3 + :param faces_in: F x 3 + :return: + ''' + faces = np.tile(np.expand_dims(faces_in,0), (verts.shape[0],1,1)) # N x F x 3 + for i in range(verts.shape[0]): + faces[i] += verts.shape[1]*i + tri_verts = verts.reshape(-1,3)[faces.reshape(-1)].reshape(verts.shape[0], faces_in.shape[0], 3, 3) # N x F x 3 x 3 + + a = tri_verts[:, :, 0, :] # N x F x 3 + b = tri_verts[:, :, 1, :] # N x F x 3 + c = tri_verts[:, :, 2, :] # N x F x 3 + v0 = (b - a).unsqueeze(1).repeat(1,points.shape[1],1,1) # N x M x F x 3 + v1 = (c - a).unsqueeze(1).repeat(1,points.shape[1],1,1) # N x M x F x 3 + v2 = points.unsqueeze(2) - a.unsqueeze(1) # N x M x F x 3 + d00 = torch.sum(v0 * v0, dim=3) # N x M x F + d01 = torch.sum(v0 * v1, dim=3) + d11 = torch.sum(v1 * v1, dim=3) + d20 = torch.sum(v2 * v0, dim=3) + d21 = torch.sum(v2 * v1, dim=3) + denom = d00 * d11 - d01 * d01 # N x M x F + v = (d11 * d20 - d01 * d21) / denom # N x M x F + w = (d00 * d21 - d01 * d20) / denom # N x M x F + u = 1 - v - w # N x M x F + + bary = torch.stack([u,v,w], dim=3) # N x M x F x 3 + + return bary diff --git a/utils/miscdetr.py b/utils/miscdetr.py new file mode 100644 index 0000000..710c901 --- /dev/null +++ b/utils/miscdetr.py @@ -0,0 +1,587 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. 
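+
+In this repository the main consumer is NestedTensor (model.py wraps the input image and its
+mask in a NestedTensor before feeding the backbone); the distributed-training and logging
+helpers are kept largely as provided by DETR / Deformable-DETR.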
+""" +import os +import random +import subprocess +import time +from collections import OrderedDict, defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import json, time +import numpy as np +import torch +import torch.distributed as dist +from torch import Tensor + +import colorsys + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +__torchvision_need_compat_flag = float(torchvision.__version__.split('.')[1]) < 7 +if __torchvision_need_compat_flag: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + if d.shape[0] == 0: + return 0 + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + 
Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + # print(name, str(meter)) + # import ipdb;ipdb.set_trace() + if meter.count > 0: + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, logger=None): + if logger is None: + print_func = print + else: + print_func = logger.info + + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + # import ipdb; ipdb.set_trace() + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print_func(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print_func(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print_func('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + 
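+        # best-effort lookup: if git is unavailable or this is not a checkout, the except
+        # below keeps the 'N/A' / 'clean' placeholders initialised above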
sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + # import ipdb; ipdb.set_trace() + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + if mask == 'auto': + self.mask = torch.zeros_like(tensors).to(tensors.device) + if self.mask.dim() == 3: + self.mask = self.mask.sum(0).to(bool) + elif self.mask.dim() == 4: + self.mask = self.mask.sum(1).to(bool) + else: + raise ValueError("tensors dim must be 3 or 4 but {}({})".format(self.tensors.dim(), self.tensors.shape)) + + def imgsize(self): + res = [] + for i in range(self.tensors.shape[0]): + mask = self.mask[i] + maxH = (~mask).sum(0).max() + maxW = (~mask).sum(1).max() + res.append(torch.Tensor([maxH, maxW])) + return res + + def to(self, device): + ## type: (Devicec) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def to_img_list_single(self, tensor, mask): + assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim()) + maxH = (~mask).sum(0).max() + maxW = (~mask).sum(1).max() + img = tensor[:, :maxH, :maxW] + return img + + def to_img_list(self): + """remove the padding and convert to img list + + Returns: + [type]: [description] + """ + if self.tensors.dim() == 3: + return self.to_img_list_single(self.tensors, self.mask) + else: + res = [] + for i in range(self.tensors.shape[0]): + tensor_i = self.tensors[i] + mask_i = self.mask[i] + res.append(self.to_img_list_single(tensor_i, mask_i)) + return res + + @property + def device(self): + return self.tensors.device + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + @property + def shape(self): + return { + 'tensors.shape': self.tensors.shape, + 'mask.shape': self.mask.shape + } + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + 
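+            # mark the area covered by the real image as valid (False); everything outside
+            # remains True, i.e. padding that downstream attention is meant to ignore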
m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': # 'RANK' in os.environ and + # args.rank = int(os.environ["RANK"]) + # args.world_size = int(os.environ['WORLD_SIZE']) + # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) + + # launch by torch.distributed.launch + # Single node + # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ... + # Multi nodes + # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... + # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... 
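+        # combine the launcher-provided env vars (WORLD_SIZE, treated here as processes per
+        # node, and LOCAL_RANK) with the --world-size/--rank arguments to derive the global
+        # world size and rank below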
+ + local_world_size = int(os.environ['WORLD_SIZE']) + args.world_size = args.world_size * local_world_size + args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) + args.rank = args.rank * local_world_size + args.local_rank + print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank)) + print(json.dumps(dict(os.environ), indent=2)) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID']) + args.world_size = int(os.environ['SLURM_NPROCS']) + + print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, args.local_rank, torch.cuda.device_count())) + else: + print('Not using distributed mode') + args.distributed = False + args.world_size = 1 + args.rank = 0 + args.local_rank = 0 + return + + print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank)) + args.distributed = True + torch.cuda.set_device(args.local_rank) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + print("Before torch.distributed.barrier()") + torch.distributed.barrier() + print("End torch.distributed.barrier()") + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if __torchvision_need_compat_flag < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + + +class color_sys(): + def __init__(self, num_colors) -> None: + self.num_colors = num_colors + colors=[] + for i in np.arange(0., 360., 360. / num_colors): + hue = i/360. + lightness = (50 + np.random.rand() * 10)/100. + saturation = (90 + np.random.rand() * 10)/100. 
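+            # colorsys.hls_to_rgb maps the evenly spaced hue (with jittered lightness and
+            # saturation) to an RGB triple in [0, 1]; scaling by 255 gives each palette
+            # index a visually distinct color for visualisation.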
+ colors.append(tuple([int(j*255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)])) + self.colors = colors + + def __call__(self, idx): + return self.colors[idx] + +def inverse_sigmoid(x, eps=1e-3): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1/x2) + +def clean_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k[:7] == 'module.': + k = k[7:] # remove `module.` + new_state_dict[k] = v + return new_state_dict \ No newline at end of file diff --git a/utils/preprocessing.py b/utils/preprocessing.py new file mode 100644 index 0000000..f6ba381 --- /dev/null +++ b/utils/preprocessing.py @@ -0,0 +1,222 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import cv2 +import numpy as np +from config import cfg +import random +import math + + +def load_img(path, order='RGB'): + + # load + img = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) + if not isinstance(img, np.ndarray): + raise IOError("Fail to read %s" % path) + + if order=='RGB': + img = img[:,:,::-1].copy() + + img = img.astype(np.float32) + return img + +def load_skeleton(path, joint_num): + + # load joint info (name, parent_id) + skeleton = [{} for _ in range(joint_num)] + with open(path) as fp: + for line in fp: + if line[0] == '#': continue + splitted = line.split(' ') + joint_name, joint_id, joint_parent_id = splitted + joint_id, joint_parent_id = int(joint_id), int(joint_parent_id) + skeleton[joint_id]['name'] = joint_name + skeleton[joint_id]['parent_id'] = joint_parent_id + # save child_id + for i in range(len(skeleton)): + joint_child_id = [] + for j in range(len(skeleton)): + if skeleton[j]['parent_id'] == i: + joint_child_id.append(j) + skeleton[i]['child_id'] = joint_child_id + + return skeleton + +def get_aug_config(): + trans_factor = 0.15 + scale_factor = 0.25 + rot_factor = 45 + color_factor = 0.2 + + trans = [np.random.uniform(-trans_factor, trans_factor), np.random.uniform(-trans_factor, trans_factor)] + scale = np.clip(np.random.randn(), -1.0, 1.0) * scale_factor + 1.0 + rot = np.clip(np.random.randn(), -2.0, + 2.0) * rot_factor if random.random() <= 0.6 else 0 + do_flip = random.random() <= 0.5 + c_up = 1.0 + color_factor + c_low = 1.0 - color_factor + color_scale = np.array([random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]) + + return trans, scale, rot, do_flip, color_scale + +def augmentation(img, bbox, joint_coord, joint_valid, hand_type, mode, joint_type): + img = img.copy(); + joint_coord = joint_coord.copy(); + hand_type = hand_type.copy(); + + original_img_shape = img.shape + joint_num = len(joint_coord) + + if mode == 'train': + trans, scale, rot, do_flip, color_scale = get_aug_config() + + else: + trans, scale, rot, do_flip, color_scale = [0,0], 1.0, 0.0, False, np.array([1,1,1]) + + bbox[0] = bbox[0] + bbox[2] * trans[0] + bbox[1] = bbox[1] + bbox[3] * trans[1] + img, trans, inv_trans = generate_patch_image(img, bbox, do_flip, scale, rot, cfg.input_img_shape) + img = np.clip(img * color_scale[None,None,:], 0, 255) + + if do_flip: + joint_coord[:,0] = original_img_shape[1] - joint_coord[:,0] - 1 + joint_coord[joint_type['right']], joint_coord[joint_type['left']] = joint_coord[joint_type['left']].copy(), joint_coord[joint_type['right']].copy() + joint_valid[joint_type['right']], 
joint_valid[joint_type['left']] = joint_valid[joint_type['left']].copy(), joint_valid[joint_type['right']].copy() + hand_type[0], hand_type[1] = hand_type[1].copy(), hand_type[0].copy() + for i in range(joint_num): + joint_coord[i,:2] = trans_point2d(joint_coord[i,:2], trans) + + return img, joint_coord, joint_valid, hand_type, inv_trans + + +def transform_input_to_output_space(joint_coord, joint_valid, rel_root_depth, root_valid, root_joint_idx, joint_type): + # transform to output heatmap space + joint_coord = joint_coord.copy(); joint_valid = joint_valid.copy() + + joint_coord[:,0] = joint_coord[:,0] / cfg.input_img_shape[1] * cfg.output_hm_shape[2] + joint_coord[:,1] = joint_coord[:,1] / cfg.input_img_shape[0] * cfg.output_hm_shape[1] + + joint_coord[joint_type['right'],2] = joint_coord[joint_type['right'],2] - joint_coord[root_joint_idx['right'],2] + joint_coord[joint_type['left'],2] = joint_coord[joint_type['left'],2] - joint_coord[root_joint_idx['left'],2] + + joint_coord[:,2] = (joint_coord[:,2] / (cfg.bbox_3d_size/2) + 1)/2. * cfg.output_hm_shape[0] + + joint_valid = joint_valid * ((joint_coord[:,2] >= 0) * (joint_coord[:,2] < cfg.output_hm_shape[0])).astype(np.float32) + joint_valid = joint_valid * ((joint_coord[:,0] >= 0) * (joint_coord[:,0] < cfg.output_hm_shape[1])).astype(np.float32) + joint_valid = joint_valid * ((joint_coord[:,1] >= 0) * (joint_coord[:,1] < cfg.output_hm_shape[2])).astype(np.float32) + + rel_root_depth = (rel_root_depth / (cfg.bbox_3d_size_root/2) + 1)/2. * cfg.output_root_hm_shape + root_valid = root_valid * ((rel_root_depth >= 0) * (rel_root_depth < cfg.output_root_hm_shape)).astype(np.float32) + + return joint_coord, joint_valid, rel_root_depth, root_valid + + +def get_bbox(joint_img, joint_valid): + x_img = joint_img[:,0][joint_valid==1]; y_img = joint_img[:,1][joint_valid==1]; + xmin = min(x_img); ymin = min(y_img); xmax = max(x_img); ymax = max(y_img); + + x_center = (xmin+xmax)/2.; width = xmax-xmin; + xmin = x_center - 0.5*width*1.2 + xmax = x_center + 0.5*width*1.2 + + y_center = (ymin+ymax)/2.; height = ymax-ymin; + ymin = y_center - 0.5*height*1.2 + ymax = y_center + 0.5*height*1.2 + + bbox = np.array([xmin, ymin, xmax-xmin, ymax-ymin]).astype(np.float32) + return bbox + +def process_bbox(bbox, original_img_shape): + + # aspect ratio preserving bbox + w = bbox[2] + h = bbox[3] + c_x = bbox[0] + w/2. + c_y = bbox[1] + h/2. + aspect_ratio = cfg.input_img_shape[1]/cfg.input_img_shape[0] + if w > aspect_ratio * h: + h = w / aspect_ratio + elif w < aspect_ratio * h: + w = h * aspect_ratio + bbox[2] = w*1.25 + bbox[3] = h*1.25 + bbox[0] = c_x - bbox[2]/2. + bbox[1] = c_y - bbox[3]/2. 
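+    # The box now matches the input aspect ratio (cfg.input_img_shape[1] / cfg.input_img_shape[0]),
+    # is enlarged by a factor of 1.25, and is re-centered at the original box center (c_x, c_y).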
+ + return bbox + +def generate_patch_image(cvimg, bbox, do_flip, scale, rot, out_shape): + img = cvimg.copy() + img_height, img_width, img_channels = img.shape + + bb_c_x = float(bbox[0] + 0.5*bbox[2]) + bb_c_y = float(bbox[1] + 0.5*bbox[3]) + bb_width = float(bbox[2]) + bb_height = float(bbox[3]) + + if do_flip: + img = img[:, ::-1, :] + bb_c_x = img_width - bb_c_x - 1 + + trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot) + img_patch = cv2.warpAffine(img, trans, (int(out_shape[1]), int(out_shape[0])), flags=cv2.INTER_LINEAR) + img_patch = img_patch.astype(np.float32) + inv_trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot, inv=True) + + return img_patch, trans, inv_trans + +def rotate_2d(pt_2d, rot_rad): + x = pt_2d[0] + y = pt_2d[1] + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + xx = x * cs - y * sn + yy = x * sn + y * cs + return np.array([xx, yy], dtype=np.float32) + +def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False): + # augment size with scale + src_w = src_width * scale + src_h = src_height * scale + src_center = np.array([c_x, c_y], dtype=np.float32) + + # augment rotation + rot_rad = np.pi * rot / 180 + src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad) + src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad) + + dst_w = dst_width + dst_h = dst_height + dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32) + dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32) + dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = src_center + src[1, :] = src_center + src_downdir + src[2, :] = src_center + src_rightdir + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = dst_center + dst[1, :] = dst_center + dst_downdir + dst[2, :] = dst_center + dst_rightdir + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + trans = trans.astype(np.float32) + return trans + +def trans_point2d(pt_2d, trans): + src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T + dst_pt = np.dot(trans, src_pt) + return dst_pt[0:2] + + diff --git a/utils/transforms.py b/utils/transforms.py new file mode 100644 index 0000000..2d86816 --- /dev/null +++ b/utils/transforms.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import numpy as np + +def cam2pixel(cam_coord, f, c): + x = cam_coord[:, 0] / (cam_coord[:, 2] + 1e-8) * f[0] + c[0] + y = cam_coord[:, 1] / (cam_coord[:, 2] + 1e-8) * f[1] + c[1] + z = cam_coord[:, 2] + img_coord = np.concatenate((x[:,None], y[:,None], z[:,None]),1) + return img_coord + +def pixel2cam(pixel_coord, f, c): + x = (pixel_coord[:, 0] - c[0]) / f[0] * pixel_coord[:, 2] + y = (pixel_coord[:, 1] - c[1]) / f[1] * pixel_coord[:, 2] + z = pixel_coord[:, 2] + cam_coord = np.concatenate((x[:,None], y[:,None], z[:,None]),1) + return cam_coord + +def world2cam(world_coord, R, T): + cam_coord = np.dot(R, world_coord - T) + return cam_coord + +def multi_meshgrid(*args): + """ + Creates a meshgrid from possibly many + elements (instead of only 2). 
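+    For example, multi_meshgrid(torch.arange(3), torch.arange(4)) returns views of
+    shape (3, 1) and (1, 4) that broadcast against each other.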
+ Returns a nd tensor with as many dimensions + as there are arguments + """ + args = list(args) + template = [1 for _ in args] + for i in range(len(args)): + n = args[i].shape[0] + template_copy = template.copy() + template_copy[i] = n + args[i] = args[i].view(*template_copy) + # there will be some broadcast magic going on + return tuple(args) + + +def flip(tensor, dims): + if not isinstance(dims, (tuple, list)): + dims = [dims] + indices = [torch.arange(tensor.shape[dim] - 1, -1, -1, + dtype=torch.int64) for dim in dims] + multi_indices = multi_meshgrid(*indices) + final_indices = [slice(i) for i in tensor.shape] + for i, dim in enumerate(dims): + final_indices[dim] = multi_indices[i] + flipped = tensor[final_indices] + assert flipped.device == tensor.device + assert flipped.requires_grad == tensor.requires_grad + return flipped diff --git a/utils/vis.py b/utils/vis.py new file mode 100644 index 0000000..0ce3cb1 --- /dev/null +++ b/utils/vis.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import os.path as osp +import cv2 +import numpy as np +import matplotlib +matplotlib.use('agg') +from mpl_toolkits.mplot3d import Axes3D +import matplotlib.pyplot as plt +import matplotlib as mpl +from config import cfg +from PIL import Image, ImageDraw + +def get_keypoint_rgb(skeleton): + rgb_dict= {} + for joint_id in range(len(skeleton)): + joint_name = skeleton[joint_id]['name'] + + if joint_name.endswith('thumb_null'): + rgb_dict[joint_name] = (255, 0, 0) + elif joint_name.endswith('thumb3'): + rgb_dict[joint_name] = (255, 51, 51) + elif joint_name.endswith('thumb2'): + rgb_dict[joint_name] = (255, 102, 102) + elif joint_name.endswith('thumb1'): + rgb_dict[joint_name] = (255, 153, 153) + elif joint_name.endswith('thumb0'): + rgb_dict[joint_name] = (255, 204, 204) + elif joint_name.endswith('index_null'): + rgb_dict[joint_name] = (0, 255, 0) + elif joint_name.endswith('index3'): + rgb_dict[joint_name] = (51, 255, 51) + elif joint_name.endswith('index2'): + rgb_dict[joint_name] = (102, 255, 102) + elif joint_name.endswith('index1'): + rgb_dict[joint_name] = (153, 255, 153) + elif joint_name.endswith('middle_null'): + rgb_dict[joint_name] = (255, 128, 0) + elif joint_name.endswith('middle3'): + rgb_dict[joint_name] = (255, 153, 51) + elif joint_name.endswith('middle2'): + rgb_dict[joint_name] = (255, 178, 102) + elif joint_name.endswith('middle1'): + rgb_dict[joint_name] = (255, 204, 153) + elif joint_name.endswith('ring_null'): + rgb_dict[joint_name] = (0, 128, 255) + elif joint_name.endswith('ring3'): + rgb_dict[joint_name] = (51, 153, 255) + elif joint_name.endswith('ring2'): + rgb_dict[joint_name] = (102, 178, 255) + elif joint_name.endswith('ring1'): + rgb_dict[joint_name] = (153, 204, 255) + elif joint_name.endswith('pinky_null'): + rgb_dict[joint_name] = (255, 0, 255) + elif joint_name.endswith('pinky3'): + rgb_dict[joint_name] = (255, 51, 255) + elif joint_name.endswith('pinky2'): + rgb_dict[joint_name] = (255, 102, 255) + elif joint_name.endswith('pinky1'): + rgb_dict[joint_name] = (255, 153, 255) + else: + rgb_dict[joint_name] = (230, 230, 0) + + return rgb_dict + + +def vis_keypoints(img, kps, kps_gt, bbox, score, skeleton, filename, score_thr=0.4, line_width=3, circle_rad = 3, save_path=None): + + rgb_dict = get_keypoint_rgb(skeleton) + _img = Image.fromarray(img.transpose(1,2,0).astype('uint8')) + 
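+    # Predictions are drawn in a fixed magenta (255, 10, 215) and ground truth in a pale
+    # green (173, 230, 216); the per-joint skeleton colors from rgb_dict are kept for
+    # reference in the commented-out lines below.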
draw = ImageDraw.Draw(_img) + # for i in range(len(skeleton)): + for i in range(len(kps_gt)): + joint_name = skeleton[i]['name'] + pid = skeleton[i]['parent_id'] + parent_joint_name = skeleton[pid]['name'] + + kps_i = (kps[i][0].astype(np.int32), kps[i][1].astype(np.int32)) + kps_pid = (kps[pid][0].astype(np.int32), kps[pid][1].astype(np.int32)) + + if score[i] > score_thr and score[pid] > score_thr and pid != -1: + # draw.line([(kps[i][0], kps[i][1]), (kps[pid][0], kps[pid][1])], fill=rgb_dict[parent_joint_name], width=line_width) + draw.line([(kps[i][0], kps[i][1]), (kps[pid][0], kps[pid][1])], fill=(255, 10, 215), width=line_width) + draw.line([(kps_gt[i][0], kps_gt[i][1]), (kps_gt[pid][0], kps_gt[pid][1])], fill=(173, 230, 216), width=line_width) + + if score[i] > score_thr: + # draw.ellipse((kps[i][0]-circle_rad, kps[i][1]-circle_rad, kps[i][0]+circle_rad, kps[i][1]+circle_rad), fill=rgb_dict[joint_name]) + draw.ellipse((kps[i][0]-circle_rad, kps[i][1]-circle_rad, kps[i][0]+circle_rad, kps[i][1]+circle_rad), fill=(255, 10, 215)) + draw.ellipse((kps_gt[i][0]-circle_rad, kps_gt[i][1]-circle_rad, kps_gt[i][0]+circle_rad, kps_gt[i][1]+circle_rad), fill=(173, 230, 216)) + + if score[pid] > score_thr and pid != -1: + # draw.ellipse((kps[pid][0]-circle_rad, kps[pid][1]-circle_rad, kps[pid][0]+circle_rad, kps[pid][1]+circle_rad), fill=rgb_dict[parent_joint_name]) + draw.ellipse((kps[pid][0]-circle_rad, kps[pid][1]-circle_rad, kps[pid][0]+circle_rad, kps[pid][1]+circle_rad), fill=(255, 10, 215)) + draw.ellipse((kps_gt[pid][0]-circle_rad, kps_gt[pid][1]-circle_rad, kps_gt[pid][0]+circle_rad, kps_gt[pid][1]+circle_rad), fill=(173, 230, 216)) + + draw.line([(bbox[0], bbox[1]), (bbox[0], bbox[1] + bbox[3])], fill = (153, 204, 255), width = 1) + draw.line([(bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1])], fill = (153, 204, 255), width = 1) + draw.line([(bbox[0] + bbox[2], bbox[1]), (bbox[0] + bbox[2], bbox[1]+ bbox[3])], fill = (153, 204, 255), width = 1) + draw.line([(bbox[0], bbox[1] + bbox[3]), (bbox[0] + bbox[2], bbox[1]+ bbox[3])], fill = (153, 204, 255), width = 1) + + + if save_path is None: + _img.save(osp.join(cfg.vis_2d_dir, filename)) + else: + _img.save(osp.join(save_path, filename)) + + plt.close() + + +def vis_3d_keypoints(kps_3d, score, skeleton, filename, score_thr=0.4, line_width=3, circle_rad=3): + + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + rgb_dict = get_keypoint_rgb(skeleton) + + for i in range(len(skeleton)): + joint_name = skeleton[i]['name'] + pid = skeleton[i]['parent_id'] + parent_joint_name = skeleton[pid]['name'] + + x = np.array([kps_3d[i,0], kps_3d[pid,0]]) + y = np.array([kps_3d[i,1], kps_3d[pid,1]]) + z = np.array([kps_3d[i,2], kps_3d[pid,2]]) + + if score[i] > score_thr and score[pid] > score_thr and pid != -1: + ax.plot(x, z, -y, c = np.array(rgb_dict[parent_joint_name])/255., linewidth = line_width) + # ax.plot(x, -y, z, c = np.array(rgb_dict[parent_joint_name])/255., linewidth = line_width) + + if score[i] > score_thr: + ax.scatter(kps_3d[i,0], kps_3d[i,2], -kps_3d[i,1], c = np.array(rgb_dict[joint_name]).reshape(1,3)/255., marker='o') + if score[pid] > score_thr and pid != -1: + ax.scatter(kps_3d[pid,0], kps_3d[pid,2], -kps_3d[pid,1], c = np.array(rgb_dict[parent_joint_name]).reshape(1,3)/255., marker='o') + + ax.set(xlim=[-200, 50], ylim=[-200, 50], zlim=[-100, 100]) + ax.view_init(5, 20) + + fig.savefig(osp.join(cfg.vis_3d_dir, filename), dpi=fig.dpi) + + + + plt.close() + +
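+# Rough usage sketch (paths, array shapes and the 42-joint count below are illustrative
+# assumptions, not fixed by this file):
+#   skeleton = load_skeleton('skeleton.txt', 42)   # helper in utils/preprocessing.py
+#   vis_keypoints(img_chw, pred_kps, gt_kps, bbox, score, skeleton, 'sample_2d.jpg')
+#   vis_3d_keypoints(pred_kps_3d, score, skeleton, 'sample_3d.jpg')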