forked from open-mmlab/mmdetection3d
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support entire PAConv and PAConvCUDA models (open-mmlab#783)
* add PAConv decode head * add config files * add paconv's correlation loss * support reg loss in Segmentor class * minor fix * add augmentation to configs * fix ed7 in cfg * fix bug in corr loss * enable syncbn in paconv * rename to loss_regularization * rename loss_reg to loss_regularize * use SyncBN * change weight kernels to kernel weights * rename corr_loss to reg_loss * minor fix * configs fix IndoorPatchPointSample * fix grouped points minus center error * update transform_3d & add configs * merge master * fix enlarge_size bug * refine config * remove cfg files * minor fix * add comments on PAConv's ScoreNet * refine comments * update compatibility doc * remove useless lines in transforms_3d * rename with_loss_regularization to with_regularization_loss * revert palette change * remove xavier init from PAConv's ScoreNet
- Loading branch information
Showing
21 changed files
with
664 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
_base_ = './paconv_ssg.py' | ||
|
||
model = dict( | ||
backbone=dict( | ||
sa_cfg=dict( | ||
type='PAConvCUDASAModule', | ||
scorenet_cfg=dict(mlp_channels=[8, 16, 16])))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# model settings | ||
model = dict( | ||
type='EncoderDecoder3D', | ||
backbone=dict( | ||
type='PointNet2SASSG', | ||
in_channels=9, # [xyz, rgb, normalized_xyz] | ||
num_points=(1024, 256, 64, 16), | ||
radius=(None, None, None, None), # use kNN instead of ball query | ||
num_samples=(32, 32, 32, 32), | ||
sa_channels=((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256, | ||
512)), | ||
fp_channels=(), | ||
norm_cfg=dict(type='BN2d', momentum=0.1), | ||
sa_cfg=dict( | ||
type='PAConvSAModule', | ||
pool_mod='max', | ||
use_xyz=True, | ||
normalize_xyz=False, | ||
paconv_num_kernels=[16, 16, 16], | ||
paconv_kernel_input='w_neighbor', | ||
scorenet_input='w_neighbor_dist', | ||
scorenet_cfg=dict( | ||
mlp_channels=[16, 16, 16], | ||
score_norm='softmax', | ||
temp_factor=1.0, | ||
last_bn=False))), | ||
decode_head=dict( | ||
type='PAConvHead', | ||
# PAConv model's decoder takes skip connections from beckbone | ||
# different from PointNet++, it also concats input features in the last | ||
# level of decoder, leading to `128 + 6` as the channel number | ||
fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128), | ||
(128 + 6, 128, 128, 128)), | ||
channels=128, | ||
dropout_ratio=0.5, | ||
conv_cfg=dict(type='Conv1d'), | ||
norm_cfg=dict(type='BN1d'), | ||
act_cfg=dict(type='ReLU'), | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', | ||
use_sigmoid=False, | ||
class_weight=None, # should be modified with dataset | ||
loss_weight=1.0)), | ||
# correlation loss to regularize PAConv's kernel weights | ||
loss_regularization=dict( | ||
type='PAConvRegularizationLoss', reduction='sum', loss_weight=10.0), | ||
# model training and testing settings | ||
train_cfg=dict(), | ||
test_cfg=dict(mode='slide')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .paconv_head import PAConvHead | ||
from .pointnet2_head import PointNet2Head | ||
|
||
__all__ = ['PointNet2Head'] | ||
__all__ = ['PointNet2Head', 'PAConvHead'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from mmcv.cnn.bricks import ConvModule | ||
|
||
from mmdet.models import HEADS | ||
from .pointnet2_head import PointNet2Head | ||
|
||
|
||
@HEADS.register_module() | ||
class PAConvHead(PointNet2Head): | ||
r"""PAConv decoder head. | ||
Decoder head used in `PAConv <https://arxiv.org/abs/2103.14635>`_. | ||
Refer to the `official code <https://github.com/CVMI-Lab/PAConv>`_. | ||
Args: | ||
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules. | ||
fp_norm_cfg (dict|None): Config of norm layers used in FP modules. | ||
Default: dict(type='BN2d'). | ||
""" | ||
|
||
def __init__(self, | ||
fp_channels=((768, 256, 256), (384, 256, 256), | ||
(320, 256, 128), (128 + 6, 128, 128, 128)), | ||
fp_norm_cfg=dict(type='BN2d'), | ||
**kwargs): | ||
super(PAConvHead, self).__init__(fp_channels, fp_norm_cfg, **kwargs) | ||
|
||
# https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/pointnet2_paconv_seg.py#L53 | ||
# PointNet++'s decoder conv has bias while PAConv's doesn't have | ||
# so we need to rebuild it here | ||
self.pre_seg_conv = ConvModule( | ||
fp_channels[-1][-1], | ||
self.channels, | ||
kernel_size=1, | ||
bias=False, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg) | ||
|
||
def forward(self, feat_dict): | ||
"""Forward pass. | ||
Args: | ||
feat_dict (dict): Feature dict from backbone. | ||
Returns: | ||
torch.Tensor: Segmentation map of shape [B, num_classes, N]. | ||
""" | ||
sa_xyz, sa_features = self._extract_input(feat_dict) | ||
|
||
# PointNet++ doesn't use the first level of `sa_features` as input | ||
# while PAConv inputs it through skip-connection | ||
fp_feature = sa_features[-1] | ||
|
||
for i in range(self.num_fp): | ||
# consume the points in a bottom-up manner | ||
fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)], | ||
sa_features[-(i + 2)], fp_feature) | ||
|
||
output = self.pre_seg_conv(fp_feature) | ||
output = self.cls_seg(output) | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy | ||
from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss | ||
from .chamfer_distance import ChamferDistance, chamfer_distance | ||
from .paconv_regularization_loss import PAConvRegularizationLoss | ||
|
||
__all__ = [ | ||
'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance', | ||
'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss' | ||
'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss', | ||
'PAConvRegularizationLoss' | ||
] |
Oops, something went wrong.