Commit

iclr2022

Hanqer committed Mar 15, 2022
1 parent 2e2327c commit ce3ada0
Showing 10 changed files with 393 additions and 29 deletions.
20 changes: 3 additions & 17 deletions config.py
@@ -50,7 +50,7 @@
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# Swin Transformer parameters
# DWNet parameters
_C.MODEL.DWNET = CN()
_C.MODEL.DWNET.PATCH_SIZE = 4
_C.MODEL.DWNET.IN_CHANS = 3
@@ -62,22 +62,8 @@
_C.MODEL.DWNET.PATCH_NORM = True
_C.MODEL.DWNET.CONV_TYPE = "v1"
_C.MODEL.DWNET.DYNAMIC = False

# halo Transformer parameters
_C.MODEL.HALO = CN()
_C.MODEL.HALO.PATCH_SIZE = 4
_C.MODEL.HALO.IN_CHANS = 3
_C.MODEL.HALO.EMBED_DIM = 96
_C.MODEL.HALO.DEPTHS = [2, 2, 6, 2]
_C.MODEL.HALO.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.HALO.WINDOW_SIZE = [7, 7, 7, 7]
_C.MODEL.HALO.HALO_SIZE = [3, 3, 3, 3]
_C.MODEL.HALO.MLP_RATIO = 4.
_C.MODEL.HALO.QKV_BIAS = True
_C.MODEL.HALO.QK_SCALE = None
_C.MODEL.HALO.APE = False
_C.MODEL.HALO.PATCH_NORM = True

_C.MODEL.DWNET.INHOMO = False
_C.MODEL.DWNET.INHOMO_HEADS = [ 4, 8, 16, 32 ]

# -----------------------------------------------------------------------------
# Training settings
12 changes: 12 additions & 0 deletions configs/i_dynamic_dwnet_base_patch4_window7_224.yaml
@@ -0,0 +1,12 @@
MODEL:
TYPE: iddwnet
NAME: iddwnet_base_patch4_window7_224
DROP_PATH_RATE: 0.5
DWNET:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
WINDOW_SIZE: 7
DYNAMIC: True
INHOMO: True
INHOMO_HEADS: [ 4, 8, 16, 32 ]
AMP_OPT_LEVEL: "O0"
12 changes: 12 additions & 0 deletions configs/i_dynamic_dwnet_tiny_patch4_window7_224.yaml
@@ -0,0 +1,12 @@
MODEL:
TYPE: iddwnet
NAME: iddwnet_tiny_patch4_window7_224
DROP_PATH_RATE: 0.2
DWNET:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 6, 2 ]
WINDOW_SIZE: 7
DYNAMIC: True
INHOMO: True
INHOMO_HEADS: [ 3, 6, 12, 24 ]
AMP_OPT_LEVEL: "O0"
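These YAML files only override defaults declared in config.py; everything else falls back to the values above. A minimal sketch of the yacs round trip, assuming config.py exposes its module-level CfgNode as _C (the usual layout in the Swin-derived codebase this repo follows):

# Sketch only; `from config import _C` assumes config.py keeps the CfgNode
# at module level, as in the Swin Transformer codebase.
from config import _C

cfg = _C.clone()
cfg.merge_from_file("configs/i_dynamic_dwnet_tiny_patch4_window7_224.yaml")

assert cfg.MODEL.DWNET.DYNAMIC and cfg.MODEL.DWNET.INHOMO
print(cfg.MODEL.DWNET.INHOMO_HEADS)  # -> [3, 6, 12, 24]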
17 changes: 17 additions & 0 deletions models/build.py
@@ -32,6 +32,23 @@ def build_model(config):
patch_norm=config.MODEL.DWNET.PATCH_NORM,
use_checkpoint=config.TRAIN.USE_CHECKPOINT,
dynamic=config.MODEL.DWNET.DYNAMIC)
elif model_type == 'iddwnet':
model = DWNet(img_size=config.DATA.IMG_SIZE,
patch_size=config.MODEL.DWNET.PATCH_SIZE,
in_chans=config.MODEL.DWNET.IN_CHANS,
num_classes=config.MODEL.NUM_CLASSES,
embed_dim=config.MODEL.DWNET.EMBED_DIM,
depths=config.MODEL.DWNET.DEPTHS,
window_size=config.MODEL.DWNET.WINDOW_SIZE,
mlp_ratio=config.MODEL.DWNET.MLP_RATIO,
drop_rate=config.MODEL.DROP_RATE,
drop_path_rate=config.MODEL.DROP_PATH_RATE,
ape=config.MODEL.DWNET.APE,
patch_norm=config.MODEL.DWNET.PATCH_NORM,
use_checkpoint=config.TRAIN.USE_CHECKPOINT,
dynamic=config.MODEL.DWNET.DYNAMIC,
inhomogeneous=config.MODEL.DWNET.INHOMO,
inhomo_heads=config.MODEL.DWNET.INHOMO_HEADS)
else:
raise NotImplementedError(f"Unknown model: {model_type}")
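The new 'iddwnet' branch amounts to constructing a DWNet with the inhomogeneous options switched on. A minimal sketch mirroring the tiny config, with arguments taken from the DWNet signature in models/dwnet.py (the import path is an assumption about the repository layout):

# Sketch mirroring configs/i_dynamic_dwnet_tiny_patch4_window7_224.yaml;
# the import path is an assumption.
from models.dwnet import DWNet

model = DWNet(img_size=224, patch_size=4, in_chans=3, num_classes=1000,
              embed_dim=96, depths=[2, 2, 6, 2], window_size=7, mlp_ratio=4.,
              drop_rate=0., drop_path_rate=0.2, ape=False, patch_norm=True,
              use_checkpoint=False, dynamic=True, inhomogeneous=True,
              inhomo_heads=[3, 6, 12, 24])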

33 changes: 23 additions & 10 deletions models/dwnet.py
@@ -3,6 +3,7 @@
import torch.utils.checkpoint as checkpoint
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from .idynamic import IDynamicDWConv


class Mlp(nn.Module):
@@ -55,18 +56,23 @@ def forward(self, x):

class DWBlock(nn.Module):

def __init__(self, dim, window_size, dynamic=False):
def __init__(self, dim, window_size, dynamic=False, inhomogeneous=False, heads=None):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.dynamic = dynamic
self.inhomogeneous = inhomogeneous
self.heads = heads

# pw-linear
self.conv0 = nn.Conv2d(dim, dim, 1, bias=False)
self.bn0 = nn.BatchNorm2d(dim)

if dynamic:
if dynamic and not inhomogeneous:
self.conv = DynamicDWConv(dim, kernel_size=window_size, stride=1, padding=window_size // 2, groups=dim)
elif dynamic and inhomogeneous:
self.conv = IDynamicDWConv(dim, window_size, heads)
else:
self.conv = nn.Conv2d(dim, dim, kernel_size=window_size, stride=1, padding=window_size // 2, groups=dim)
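IDynamicDWConv is defined in models/idynamic.py, which is added elsewhere in this commit and not shown here, so only its call signature IDynamicDWConv(dim, window_size, heads) is visible. Purely as a hypothetical sketch of an inhomogeneous dynamic depth-wise convolution, with one k x k kernel predicted per head at every spatial position and shared by the dim // heads channels of that head, it might look like:

# Hypothetical sketch only: the real IDynamicDWConv is not shown in this
# diff, so the layer names and kernel layout below are guesses.
import torch
import torch.nn as nn
import torch.nn.functional as F

class IDynamicDWConvSketch(nn.Module):
    def __init__(self, dim, kernel_size, heads, reduction=4):
        super().__init__()
        self.kernel_size = kernel_size
        self.heads = heads
        # predict one k*k kernel per head at every pixel
        self.kernel_gen = nn.Sequential(
            nn.Conv2d(dim, dim // reduction, 1),
            nn.BatchNorm2d(dim // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim // reduction, heads * kernel_size ** 2, 1))

    def forward(self, x):
        B, C, H, W = x.shape
        k, h = self.kernel_size, self.heads
        # B x h x 1 x k*k x H*W: one kernel per head and position
        kernels = self.kernel_gen(x).view(B, h, 1, k * k, H * W)
        # B x h x C//h x k*k x H*W: the k x k neighborhood of every pixel
        patches = F.unfold(x, k, padding=k // 2).view(B, h, C // h, k * k, H * W)
        # weighted sum over the window, then back to B x C x H x W
        return (kernels * patches).sum(dim=3).view(B, C, H, W)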

@@ -80,7 +86,7 @@ def __init__(self, dim, window_size, dynamic=False):
def forward(self, x):
B, H, W, C = x.shape

x = x.permute(0, 3, 1, 2)
x = x.permute(0, 3, 1, 2).contiguous()
x = self.conv0(x)
x = self.bn0(x)
x = self.relu(x)
@@ -92,7 +98,7 @@ def forward(self, x):
x = self.conv2(x)
x = self.bn2(x)

x = x.permute(0, 2, 3, 1)
x = x.permute(0, 2, 3, 1).contiguous()
return x

def extra_repr(self) -> str:
@@ -104,8 +110,10 @@ def flops(self, N):
# x = self.conv0(x)
flops += N * self.dim * self.dim
# x = self.conv(x)
if self.dynamic:
if self.dynamic and not self.inhomogeneous:
flops += (N * self.dim + self.dim * self.dim / 4 + self.dim / 4 * self.dim * self.window_size * self.window_size)
elif self.dynamic and self.inhomogeneous:
flops += (N * self.dim * self.dim / 4 + N * self.dim / 4 * self.dim / self.heads * self.window_size * self.window_size)
flops += N * self.dim * self.window_size * self.window_size
# x = self.conv2(x)
flops += N * self.dim * self.dim
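Plugging the tiny-variant stage-1 numbers into the inhomogeneous branch gives a feel for the cost (back-of-envelope; dim = 96, heads = 3, window_size = 7 and N = 56 * 56 tokens are assumed from the tiny config):

# Back-of-envelope instantiation of the flops formulas above with assumed
# tiny-model stage-1 numbers.
N, dim, heads, k = 56 * 56, 96, 3, 7
kernel_gen = N * dim * dim / 4 + N * dim / 4 * dim / heads * k * k
conv_apply = N * dim * k * k
print(f"kernel generation: {kernel_gen / 1e6:.1f} MFLOPs")  # ~125.2
print(f"depth-wise apply:  {conv_apply / 1e6:.1f} MFLOPs")  # ~14.8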
@@ -117,15 +125,15 @@ def flops(self, N):
class SpatialBlock(nn.Module):

def __init__(self, dim, input_resolution, window_size=7,
mlp_ratio=4., drop=0., drop_path=0., dynamic=False, act_layer=nn.GELU):
mlp_ratio=4., drop=0., drop_path=0., dynamic=False, inhomogeneous=False, inhomo_head=None, act_layer=nn.GELU):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.dynamic = dynamic

self.attn2conv = DWBlock(dim, window_size, dynamic)
self.attn2conv = DWBlock(dim, window_size, dynamic, inhomogeneous, inhomo_head)

self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

@@ -211,7 +219,7 @@ class BasicLayer(nn.Module):

def __init__(self, dim, input_resolution, depth, window_size,
mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm,
downsample=None, use_checkpoint=False, dynamic=False):
downsample=None, use_checkpoint=False, dynamic=False, inhomogeneous=False, inhomo_head=None):

super().__init__()
self.dim = dim
@@ -226,6 +234,8 @@ def __init__(self, dim, input_resolution, depth, window_size,
mlp_ratio=mlp_ratio,
drop=drop,
dynamic=dynamic,
inhomogeneous=inhomogeneous,
inhomo_head=inhomo_head,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path)
for i in range(depth)])

@@ -301,7 +311,7 @@ class DWNet(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], window_size=7, mlp_ratio=4.,
drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
ape=False, patch_norm=True, use_checkpoint=False, dynamic=False, **kwargs):
ape=False, patch_norm=True, use_checkpoint=False, dynamic=False, inhomogeneous=False, inhomo_heads=None, **kwargs):
super().__init__()

self.num_classes = num_classes
@@ -344,7 +354,10 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
dynamic=dynamic)
dynamic=dynamic,
inhomogeneous=inhomogeneous,
inhomo_head=inhomo_heads[i_layer])

self.layers.append(layer)

self.norm = norm_layer(self.num_features)