-
Notifications
You must be signed in to change notification settings - Fork 0
/
dispnet_c.py
125 lines (119 loc) · 5.85 KB
/
dispnet_c.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, Sequential
from correlation import correlation
class DispNet(tf.keras.Model):
    """DispNet-C style stereo disparity network.

    A Siamese convolutional encoder processes the left and right images
    with separate weights, the two feature maps are fused through an
    explicit correlation volume, and a decoder with skip connections
    predicts a disparity map at five scales, coarse to fine.

    Call signature:
        out, c26, c24, c22, c20 = model(left, right, training=...)
    where ``out`` is the finest-resolution prediction and the remaining
    tensors are the progressively coarser intermediate predictions
    (useful for multi-scale supervision during training).
    """

    def __init__(self):
        super(DispNet, self).__init__()
        # --- Siamese encoder: left-image branch (overall stride 8) ---
        self.conv1a = layers.Conv2D(64, (7, 7), padding='same', activation='relu')
        self.pool1a = layers.MaxPool2D((2, 2))
        self.conv3a = layers.Conv2D(128, (5, 5), padding='same', activation='relu')
        self.pool3a = layers.MaxPool2D((2, 2))
        self.conv17a = layers.Conv2D(256, (5, 5), padding='same', activation='relu')
        self.pool8a = layers.MaxPool2D((2, 2))
        # --- Siamese encoder: right-image branch (separate weights) ---
        self.conv1b = layers.Conv2D(64, (7, 7), padding='same', activation='relu')
        self.pool1b = layers.MaxPool2D((2, 2))
        self.conv3b = layers.Conv2D(128, (5, 5), padding='same', activation='relu')
        self.pool3b = layers.MaxPool2D((2, 2))
        self.conv17b = layers.Conv2D(256, (5, 5), padding='same', activation='relu')
        self.pool8b = layers.MaxPool2D((2, 2))
        # NOTE(review): `self.corr` is never used — `call` computes the cost
        # volume with the imported `correlation` function instead (see the
        # commented-out lines there). Kept so the variable set of any
        # existing checkpoints is unchanged; safe to delete otherwise.
        self.corr = layers.DepthwiseConv2D(kernel_size=(1, 1), strides=(1, 1), padding='valid', depth_multiplier=1)
        # 1x1 "redirect" conv that carries left features past the correlation.
        self.conva = layers.Conv2D(32, (1, 1), padding='same', activation='relu')
        # --- Contracting trunk after the correlation fusion ---
        self.conv4 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        self.conv9 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.pool5 = layers.MaxPool2D((2, 2))
        self.conv10 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.conv11 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.pool6 = layers.MaxPool2D((2, 2))
        self.conv12 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.conv13 = layers.Conv2D(1024, (3, 3), padding='same', activation='relu')
        self.pool7 = layers.MaxPool2D((2, 2))
        self.conv14 = layers.Conv2D(1024, (3, 3), padding='same', activation='relu')
        # --- Expanding decoder: each stage predicts a 1-channel disparity
        # map, upsamples it, and concatenates it with deconvolved features
        # and an encoder skip connection ---
        self.conv18 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up1 = layers.UpSampling2D((2, 2))
        self.deconv4 = layers.Conv2DTranspose(512, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn1 = layers.BatchNormalization()
        self.conv19 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.conv20 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up2 = layers.UpSampling2D((2, 2))
        self.deconv5 = layers.Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn2 = layers.BatchNormalization()
        self.conv21 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        self.deconv24 = layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn3 = layers.BatchNormalization()
        self.conv22 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up3 = layers.UpSampling2D((2, 2))
        self.conv23 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')
        self.conv24 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up4 = layers.UpSampling2D((2, 2))
        self.deconv7 = layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn4 = layers.BatchNormalization()
        self.conv25 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')
        self.conv26 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')
        self.up5 = layers.UpSampling2D((2, 2))
        self.deconv8 = layers.Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same', activation='relu')
        self.bn5 = layers.BatchNormalization()
        self.conv27 = layers.Conv2D(32, (3, 3), padding='same', activation='relu')
        self.conv28 = layers.Conv2D(1, (3, 3), padding='same', activation='relu')

    def call(self, left, right, training=None):
        """Run a forward pass over a stereo pair.

        Args:
            left: left image batch, channels-last float tensor
                (all concatenations below use axis=3).
            right: right image batch, same shape as ``left``.
            training: standard Keras training flag; forwarded to the
                BatchNormalization layers so they use batch statistics
                in training and moving averages at inference.

        Returns:
            Tuple ``(out, c26, c24, c22, c20)`` of disparity predictions
            from finest to coarsest resolution.
        """
        # Encoder, left branch.
        c1a = self.conv1a(left)
        p1a = self.pool1a(c1a)
        c3a = self.conv3a(p1a)
        p3a = self.pool3a(c3a)
        c17a = self.conv17a(p3a)
        p8a = self.pool8a(c17a)
        # Encoder, right branch.
        c1b = self.conv1b(right)
        p1b = self.pool1b(c1b)
        c3b = self.conv3b(p1b)
        p3b = self.pool3b(c3b)
        c17b = self.conv17b(p3b)
        p8b = self.pool8b(c17b)
        # Fuse the two branches with an explicit correlation volume.
        # (An earlier depthwise-conv approximation was abandoned:)
        # c = tf.concat([p8a, p8b], axis=3)
        # cc = self.corr(c)
        cc = correlation(p8a, p8b)
        cc = tf.nn.leaky_relu(cc, 0.1)
        # Redirect left features past the correlation and concatenate.
        ca = self.conva(p8a)
        net = tf.concat([ca, cc], axis=3)
        # Contracting trunk.
        c4 = self.conv4(net)
        c9 = self.conv9(c4)
        p5 = self.pool5(c9)
        c10 = self.conv10(p5)
        c11 = self.conv11(c10)
        p6 = self.pool6(c11)
        c12 = self.conv12(p6)
        c13 = self.conv13(c12)
        p7 = self.pool7(c13)
        c14 = self.conv14(p7)
        # Coarsest disparity prediction.
        c18 = self.conv18(c14)
        u1 = self.up1(c18)
        # Decoder stage 1 (skip from c12).
        d4 = self.deconv4(c14)
        # BUG FIX: forward `training` explicitly — inside a subclassed
        # Model's custom `call`, Keras does not propagate the flag to
        # sub-layers automatically, so the BN layers previously ignored
        # the caller-supplied training mode. Same fix for bn2..bn5 below.
        b1 = self.bn1(d4, training=training)
        merge_2 = tf.concat([c12, b1, u1], axis=3)
        c19 = self.conv19(merge_2)
        c20 = self.conv20(c19)
        u2 = self.up2(c20)
        # Decoder stage 2 (skip from c10).
        d5 = self.deconv5(c19)
        b2 = self.bn2(d5, training=training)
        merge_3 = tf.concat([c10, b2, u2], axis=3)
        c21 = self.conv21(merge_3)
        # Decoder stage 3 (skip from c4).
        d24 = self.deconv24(c21)
        b3 = self.bn3(d24, training=training)
        c22 = self.conv22(c21)
        u3 = self.up3(c22)
        merge_4 = tf.concat([c4, b3, u3], axis=3)
        c23 = self.conv23(merge_4)
        c24 = self.conv24(c23)
        u4 = self.up4(c24)
        # Decoder stage 4 (skip from the left encoder, p3a).
        d7 = self.deconv7(c23)
        b4 = self.bn4(d7, training=training)
        merge_5 = tf.concat([p3a, b4, u4], axis=3)  # was layers.concatenate; tf.concat for consistency
        c25 = self.conv25(merge_5)
        c26 = self.conv26(c25)
        u5 = self.up5(c26)
        # Decoder stage 5 (skip from the left encoder, p1a).
        d8 = self.deconv8(c25)
        b5 = self.bn5(d8, training=training)
        merge_6 = tf.concat([p1a, b5, u5], axis=3)
        c27 = self.conv27(merge_6)
        out = self.conv28(c27)
        return out, c26, c24, c22, c20