Skip to content
This repository was archived by the owner on Jul 19, 2019. It is now read-only.

Commit f001385

Browse files
author
ycszen
committed
add DUC, GCN
1 parent 80c8619 commit f001385

File tree

6 files changed

+580
-0
lines changed

6 files changed

+580
-0
lines changed

Diff for: README.md

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Pytorch for Semantic Segmentation
2+
A repository that contains some existing networks and some experimental networks for semantic segmentation.
3+
+ [x] ResNet(FCN)
4+
- [x] ResNet-50
5+
- [x] ResNet-101
6+
- [ ] Wide-ResNet
7+
+ [x] DUC(*Understanding Convolution for Semantic Segmentation*)[pdf](https://arxiv.org/abs/1702.08502)
8+
+ [x] GCN(*Large Kernel Matters -- Improve Semantic Segmentation by Global Convolutional Network*)[pdf](https://arxiv.org/abs/1703.02719)

Diff for: duc.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import torch.nn.init as init
5+
import torch.utils.model_zoo as model_zoo
6+
from torchvision import models
7+
8+
import math
9+
10+
11+
class DUC(nn.Module):
    """Dense Upsampling Convolution block: conv -> BN -> ReLU -> PixelShuffle.

    A 3x3 convolution expands the channel dimension to `planes`, and
    PixelShuffle then trades channels for spatial resolution, multiplying
    H and W by `upscale_factor` and dividing channels by upscale_factor**2.
    """

    def __init__(self, inplanes, planes, upscale_factor=2):
        super(DUC, self).__init__()
        # Submodule attribute names are kept identical to the original so
        # saved state_dicts remain loadable.
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(inplanes, planes, kernel_size=3,
                              padding=1)
        self.bn = nn.BatchNorm2d(planes)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        # conv -> bn -> relu, then rearrange channels into spatial pixels.
        out = self.pixel_shuffle(self.relu(self.bn(self.conv(x))))
        return out
26+
27+
class FCN(nn.Module):
    """Segmentation network: ResNet-50 encoder + DUC upsampling decoder.

    forward() returns score maps at five scales, finest first:
    (out, out2, out4, out8, out16), each with `num_classes` channels.
    """

    def __init__(self, num_classes):
        super(FCN, self).__init__()

        self.num_classes = num_classes

        # Pretrained ResNet-50 backbone; reuse its stem and four stages.
        resnet = models.resnet50(pretrained=True)

        self.conv1 = resnet.conv1
        self.bn0 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # DUC(c, 2c) with PixelShuffle(2): doubles H/W, outputs 2c/4 channels.
        self.duc1 = DUC(2048, 2048 * 2)
        self.duc2 = DUC(1024, 1024 * 2)
        self.duc3 = DUC(512, 512 * 2)
        self.duc4 = DUC(128, 128 * 2)
        self.duc5 = DUC(64, 64 * 2)

        # One scoring head per decoder scale.
        self.out1 = self._classifier(1024)
        self.out2 = self._classifier(512)
        self.out3 = self._classifier(128)
        self.out4 = self._classifier(64)
        self.out5 = self._classifier(32)

        # Fuses the concatenated (dfm3, pool_x) = 256 + 64 channels to 128.
        self.transformer = nn.Conv2d(320, 128, kernel_size=1)

    def _classifier(self, inplanes):
        """Build a scoring head mapping `inplanes` channels to num_classes.

        BUGFIX: the original used `inplanes / 2`, which is a float under
        Python 3 and crashes nn.Conv2d / nn.BatchNorm2d; integer division
        (`//`) is required.
        """
        if inplanes == 32:
            # Too few channels to halve; score directly.
            return nn.Sequential(
                nn.Conv2d(inplanes, self.num_classes, 1),
                nn.Conv2d(self.num_classes, self.num_classes,
                          kernel_size=3, padding=1)
            )
        return nn.Sequential(
            nn.Conv2d(inplanes, inplanes // 2, 3, padding=1, bias=False),
            nn.BatchNorm2d(inplanes // 2, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(inplanes // 2, self.num_classes, 1),
        )

    def forward(self, x):
        # --- encoder (ResNet-50 stem + stages) ---
        x = self.conv1(x)
        x = self.bn0(x)
        x = self.relu(x)
        conv_x = x              # stem output before pooling (64 ch)
        x = self.maxpool(x)
        pool_x = x              # pooled stem output (64 ch)

        fm1 = self.layer1(x)
        fm2 = self.layer2(fm1)
        fm3 = self.layer3(fm2)
        fm4 = self.layer4(fm3)

        # --- decoder: DUC-upsample and add the matching skip feature ---
        dfm1 = fm3 + self.duc1(fm4)
        out16 = self.out1(dfm1)

        dfm2 = fm2 + self.duc2(dfm1)
        out8 = self.out2(dfm2)

        dfm3 = fm1 + self.duc3(dfm2)

        # Concatenate with the pooled stem feature and reduce to 128 ch.
        dfm3_t = self.transformer(torch.cat((dfm3, pool_x), 1))
        out4 = self.out3(dfm3_t)

        dfm4 = conv_x + self.duc4(dfm3_t)
        out2 = self.out4(dfm4)

        dfm5 = self.duc5(dfm4)
        out = self.out5(dfm5)

        return out, out2, out4, out8, out16

Diff for: gcn.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import torch.nn.init as init
5+
import torch.utils.model_zoo as model_zoo
6+
from torchvision import models
7+
8+
import math
9+
10+
11+
class GCN(nn.Module):
    """Global Convolutional Network block (Large Kernel Matters).

    Approximates a dense ks x ks convolution with two separable branches —
    (ks,1) then (1,ks), and (1,ks) then (ks,1) — and sums their outputs.
    With odd `ks` the padding keeps the spatial size unchanged.
    """

    def __init__(self, inplanes, planes, ks=7):
        super(GCN, self).__init__()
        # BUGFIX: the original used ks/2, which is a float under Python 3;
        # nn.Conv2d padding entries must be integers.
        pad = ks // 2
        self.conv_l1 = nn.Conv2d(inplanes, planes, kernel_size=(ks, 1),
                                 padding=(pad, 0))
        self.conv_l2 = nn.Conv2d(planes, planes, kernel_size=(1, ks),
                                 padding=(0, pad))
        self.conv_r1 = nn.Conv2d(inplanes, planes, kernel_size=(1, ks),
                                 padding=(0, pad))
        self.conv_r2 = nn.Conv2d(planes, planes, kernel_size=(ks, 1),
                                 padding=(pad, 0))

    def forward(self, x):
        # Left branch: vertical then horizontal strip convolution.
        x_l = self.conv_l1(x)
        x_l = self.conv_l2(x_l)

        # Right branch: horizontal then vertical strip convolution.
        x_r = self.conv_r1(x)
        x_r = self.conv_r2(x_r)

        # Sum the two separable approximations.
        x = x_l + x_r

        return x
34+
35+
36+
class Refine(nn.Module):
    """Boundary-refinement residual block: x + conv(relu(bn(conv(relu(bn(x)))))).

    NOTE(review): a single BatchNorm module (self.bn) is applied at both
    positions, so the two normalizations share affine parameters and running
    statistics — preserved exactly as in the original.
    """

    def __init__(self, planes):
        super(Refine, self).__init__()
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)

    def forward(self, x):
        identity = x
        y = self.conv1(self.relu(self.bn(x)))
        y = self.conv2(self.relu(self.bn(y)))
        return identity + y
55+
56+
57+
class FCN(nn.Module):
    """GCN-style segmentation network on a ResNet-50 backbone.

    forward() returns (out, fs4, fs3, fs2, fs1, gcfm1): the full-resolution
    score map followed by the intermediate fused maps, all with
    `num_classes` channels.
    """

    def __init__(self, num_classes):
        super(FCN, self).__init__()

        self.num_classes = num_classes

        # Pretrained ResNet-50 backbone; reuse its stem and four stages.
        resnet = models.resnet50(pretrained=True)

        self.conv1 = resnet.conv1
        self.bn0 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # One GCN per backbone scale, projecting features to num_classes.
        self.gcn1 = GCN(2048, self.num_classes)
        self.gcn2 = GCN(1024, self.num_classes)
        self.gcn3 = GCN(512, self.num_classes)
        self.gcn4 = GCN(64, self.num_classes)
        self.gcn5 = GCN(64, self.num_classes)

        self.refine1 = Refine(self.num_classes)
        self.refine2 = Refine(self.num_classes)
        self.refine3 = Refine(self.num_classes)
        self.refine4 = Refine(self.num_classes)
        self.refine5 = Refine(self.num_classes)
        self.refine6 = Refine(self.num_classes)
        self.refine7 = Refine(self.num_classes)
        self.refine8 = Refine(self.num_classes)
        self.refine9 = Refine(self.num_classes)
        self.refine10 = Refine(self.num_classes)

        # NOTE(review): these heads and self.transformer are never called in
        # forward(); they are kept so existing checkpoints remain loadable.
        self.out0 = self._classifier(2048)
        self.out1 = self._classifier(1024)
        self.out2 = self._classifier(512)
        self.out_e = self._classifier(256)
        self.out3 = self._classifier(64)
        self.out4 = self._classifier(64)
        self.out5 = self._classifier(32)

        self.transformer = nn.Conv2d(256, 64, kernel_size=1)

    def _classifier(self, inplanes):
        """Build a scoring head mapping `inplanes` channels to num_classes.

        BUGFIX: the first conv must output `inplanes // 2` channels to match
        the following BatchNorm and the final 1x1 conv; the original produced
        `inplanes` channels (a channel mismatch) and used `inplanes / 2`,
        which is a float under Python 3 and crashes the layer constructors.
        """
        return nn.Sequential(
            nn.Conv2d(inplanes, inplanes // 2, 3, padding=1, bias=False),
            nn.BatchNorm2d(inplanes // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(inplanes // 2, self.num_classes, 1),
        )

    def forward(self, x):
        inp = x  # keep the full-resolution size for the final upsample
        x = self.conv1(x)
        x = self.bn0(x)
        x = self.relu(x)
        conv_x = x              # stem output before pooling (64 ch)
        x = self.maxpool(x)
        pool_x = x              # pooled stem output (64 ch)

        fm1 = self.layer1(x)
        fm2 = self.layer2(fm1)
        fm3 = self.layer3(fm2)
        fm4 = self.layer4(fm3)

        # GCN + boundary refinement at each scale.
        gcfm1 = self.refine1(self.gcn1(fm4))
        gcfm2 = self.refine2(self.gcn2(fm3))
        gcfm3 = self.refine3(self.gcn3(fm2))
        gcfm4 = self.refine4(self.gcn4(pool_x))
        gcfm5 = self.refine5(self.gcn5(conv_x))

        # Top-down fusion: upsample, add the next finer map, refine.
        fs1 = self.refine6(F.upsample_bilinear(gcfm1, fm3.size()[2:]) + gcfm2)
        fs2 = self.refine7(F.upsample_bilinear(fs1, fm2.size()[2:]) + gcfm3)
        fs3 = self.refine8(F.upsample_bilinear(fs2, pool_x.size()[2:]) + gcfm4)
        fs4 = self.refine9(F.upsample_bilinear(fs3, conv_x.size()[2:]) + gcfm5)
        out = self.refine10(F.upsample_bilinear(fs4, inp.size()[2:]))

        return out, fs4, fs3, fs2, fs1, gcfm1

Diff for: tester.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Inference script: run a trained segmentation FCN on a VOC test image and
# save the colorized predictions for every output scale.
import torch
from torch.utils import data
import torch.optim as optim
from torch.autograd import Variable
from transform import Colorize
from torchvision.transforms import ToPILImage, Compose, ToTensor, CenterCrop
from transform import Scale
# from resnet import FCN
from upsample import FCN
# from gcn import FCN
from datasets import VOCTestSet
from PIL import Image
import numpy as np
from tqdm import tqdm


# Every input is resized to 256x256 before being fed to the network.
label_transform = Compose([Scale((256, 256), Image.BILINEAR), ToTensor()])
batch_size = 1
dst = VOCTestSet("./data", transform=label_transform)

testloader = data.DataLoader(dst, batch_size=batch_size,
                             num_workers=8)


# Wrap in DataParallel so the checkpoint's "module."-prefixed keys match.
model = torch.nn.DataParallel(FCN(22), device_ids=[0, 1, 2, 3])
# model = FCN(22)
model.cuda()
model.load_state_dict(torch.load("./pth/fcn-deconv-40.pth"))
model.eval()


# 10 13 48 86 101
source = Image.open("./data/VOC2012test/JPEGImages/2008_000101.jpg").convert("RGB")
original_size = source.size
source.save("original.png")
tensor = ToTensor()(source.resize((256, 256), Image.BILINEAR))
outputs = model(Variable(tensor).unsqueeze(0))
# 22 256 256
for i, score_map in enumerate(outputs):
    # Argmax over the class dimension, then map labels to colors.
    labels = score_map[0].data.max(0)[1]
    colored = np.transpose(Colorize()(labels).numpy(), (1, 2, 0))
    pred = Image.fromarray(colored, "RGB")
    if i == 0:
        # Only the full-resolution output is restored to the original size.
        pred = pred.resize(original_size, Image.NEAREST)
    pred.save("test-%d.png" % i)

'''

for index, (imgs, name, size) in tqdm(enumerate(testloader)):
    imgs = Variable(imgs.cuda())
    outputs = model(imgs)

    output = outputs[0][0].data.max(0)[1]
    output = Colorize()(output)
    print(output)
    output = np.transpose(output.numpy(), (1, 2, 0))
    img = Image.fromarray(output, "RGB")
    # img = Image.fromarray(output[0].cpu().numpy(), "P")
    img = img.resize((size[0].numpy(), size[1].numpy()), Image.NEAREST)
    img.save("./results/VOC2012/Segmentation/comp5_test_cls/%s.png" % name)
'''

0 commit comments

Comments
 (0)