-
Notifications
You must be signed in to change notification settings - Fork 3
/
unetV2.py
119 lines (89 loc) · 4.83 KB
/
unetV2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 5 10:36:33 2018
https://github.com/jocicmarko/ultrasound-nerve-segmentation/blob/master/train.py
@author: IbrahimD
"""
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras.layers.normalization import BatchNormalization
from keras import backend as K
from keras.layers.core import Activation, SpatialDropout2D
def block(prevlayer, a, b, pooling):
    """Build an inception-style block: four parallel conv branches, concatenated.

    Branches applied to ``prevlayer`` (every conv uses ReLU, 'same' padding
    and is followed by BatchNormalization):
      * 3x3 conv (``a`` filters) -> 3x3 conv (``b`` filters)
      * 5x5 conv (``a`` filters) -> 5x5 conv (``b`` filters)
      * 1x1 conv (``b`` filters)
      * 3x3 conv (``a`` filters) -> 1x1 conv (``b`` filters)
    When ``pooling`` is truthy, each branch is 2x2 max-pooled before the
    concatenation, halving the spatial size.

    Args:
        prevlayer: input Keras tensor.
        a: filter count of the first conv in the two-conv branches.
        b: filter count of the last conv in every branch.
        pooling: whether to 2x2 max-pool each branch.

    Returns:
        Keras tensor concatenating the four branches along the default
        (last) axis, i.e. ``4 * b`` channels assuming channels_last data.
    """
    def _branch(specs):
        # Stack Conv2D + BatchNormalization for each (filters, kernel) pair,
        # then optionally pool. Layers are created in the same order as the
        # original hand-written code, so auto-generated layer names (and thus
        # weight loading order) are unchanged.
        x = prevlayer
        for filters, kernel in specs:
            x = Conv2D(filters, kernel, activation='relu', padding='same')(x)
            x = BatchNormalization()(x)
        if pooling:
            x = MaxPooling2D(pool_size=(2, 2))(x)
        return x

    conva = _branch([(a, (3, 3)), (b, (3, 3))])
    convb = _branch([(a, (5, 5)), (b, (5, 5))])
    convc = _branch([(b, (1, 1))])
    convd = _branch([(a, (3, 3)), (b, (1, 1))])
    return concatenate([conva, convb, convc, convd])
def conv_block_simple(prevlayer, filters, prefix, strides=(1, 1)):
    """Apply a named 3x3 Conv2D -> BatchNormalization -> ReLU stack.

    The three layers are named ``<prefix>_conv``, ``<prefix>_bn`` and
    ``<prefix>_activation``, so callers must pass a distinct ``prefix`` per
    call to keep layer names unique within a model.

    Args:
        prevlayer: input Keras tensor.
        filters: number of output channels of the convolution.
        prefix: string prepended to each layer name.
        strides: convolution strides; the default (1, 1) together with
            'same' padding keeps the spatial size.

    Returns:
        The ReLU-activated Keras tensor.
    """
    x = Conv2D(
        filters,
        (3, 3),
        strides=strides,
        padding="same",
        kernel_initializer="he_normal",
        name=prefix + "_conv",
    )(prevlayer)
    x = BatchNormalization(name=prefix + "_bn")(x)
    return Activation('relu', name=prefix + "_activation")(x)
# Default input geometry used by get_unet_plus_inception(). Rows and cols
# must be divisible by 16 so its four 2x2 poolings and four 2x upsamplings
# line up for the skip-connection concatenations.
img_rows = img_cols = 224
depth = 3  # number of input channels
# Additive smoothing for dice_coef: keeps the ratio defined (and equal to 1)
# when both masks are empty.
smooth = 1.0
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient between two tensors.

    Both inputs are flattened, and the score
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` is computed
    over all their elements, so it is a single value for the whole batch
    rather than per sample. ``smooth`` is the module-level constant.

    Args:
        y_true: ground-truth tensor (presumably a binary mask — confirm).
        y_pred: predicted tensor of the same number of elements.

    Returns:
        Scalar Keras tensor with the smoothed Dice score.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator
def dice_coef_loss(y_true, y_pred):
    """Return the negated Dice coefficient, for use as a training loss.

    Negating the score means that minimizing this loss maximizes
    ``dice_coef(y_true, y_pred)``.
    """
    score = dice_coef(y_true, y_pred)
    return -score
def get_unet_plus_inception(input_shape=None):
    """Build a U-Net whose decoder also receives inception-block features.

    Encoder: four levels of two 3x3 ReLU convs (32, 64, 128, 256 filters),
    each followed by 2x2 max-pooling, then a 512-filter bottleneck.
    Inception path: ``block`` is applied directly to the input (16 filters,
    no pooling), then three more times with pooling (32, 64, 128 filters),
    producing one tensor per decoder resolution.
    Decoder: at each level a 2x2 stride-2 transposed conv upsamples, the
    result is concatenated with the matching encoder map and inception-path
    map, and two 3x3 ReLU convs refine it. A final 1x1 sigmoid conv yields a
    single-channel output (presumably a segmentation probability map).

    Args:
        input_shape: ``(rows, cols, channels)`` of the input. Defaults to
            ``(img_rows, img_cols, depth)``. ``rows`` and ``cols`` must be
            divisible by 16 or the skip concatenations will not line up.

    Returns:
        An uncompiled keras ``Model`` mapping the input to a one-channel
        sigmoid map with the same spatial size as the input.
    """
    if input_shape is None:
        input_shape = (img_rows, img_cols, depth)

    def _double_conv(x, filters):
        # Two 3x3 ReLU convolutions with 'same' padding (spatial size kept).
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        return Conv2D(filters, (3, 3), activation='relu', padding='same')(x)

    def _pool(x):
        # 2x2 max-pooling halves each spatial dimension.
        return MaxPooling2D(pool_size=(2, 2))(x)

    def _up_block(x, filters, skip, side):
        # Upsample x 2x with a learned transposed conv, fuse it with the
        # encoder skip and the inception-path tensor on axis 3 (the channel
        # axis assuming channels_last), then refine with two 3x3 convs.
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        merged = concatenate([upsampled, skip, side], axis=3)
        return _double_conv(merged, filters)

    # Layers are created in exactly the original order, so auto-generated
    # layer names and saved-weight ordering are unchanged.
    inputs = Input(input_shape)

    # Plain U-Net encoder.
    conv1 = _double_conv(inputs, 32)
    conv2 = _double_conv(_pool(conv1), 64)
    conv3 = _double_conv(_pool(conv2), 128)
    conv4 = _double_conv(_pool(conv3), 256)
    conv5 = _double_conv(_pool(conv4), 512)

    # Inception path; xxN matches the spatial size of convN.
    xx1 = block(inputs, 16, 16, False)
    xx2 = block(xx1, 32, 32, True)
    xx3 = block(xx2, 64, 64, True)
    xx4 = block(xx3, 128, 128, True)

    # Decoder with dual skip connections.
    conv6 = _up_block(conv5, 256, conv4, xx4)
    conv7 = _up_block(conv6, 128, conv3, xx3)
    conv8 = _up_block(conv7, 64, conv2, xx2)
    conv9 = _up_block(conv8, 32, conv1, xx1)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    return Model(inputs=[inputs], outputs=[outputs])