-
Notifications
You must be signed in to change notification settings - Fork 0
/
optimize.py
150 lines (129 loc) · 5.9 KB
/
optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
'''
@Description: Training routine that minimizes combined content, style and total-variation losses.
@version:
@Author: 孙工
@Date: 2020-04-09 17:33:40
'''
from __future__ import print_function
import functools
import vgg, pdb, time
import tensorflow as tf, numpy as np, os
import transform
from utils import get_img
# Layers whose Gram matrices are compared to compute the style loss.
STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
# Layer whose activations are compared to compute the content loss.
CONTENT_LAYER = 'relu4_2'
# NOTE(review): not referenced in this file; presumably the name of the
# environment variable used by callers to select GPUs - confirm.
DEVICES = 'CUDA_VISIBLE_DEVICES'
def optimize(content_targets, style_target, content_weight, style_weight,
             tv_weight, vgg_path, epochs=2, print_iterations=1000,
             batch_size=4, save_path='saver/fns.ckpt', slow=False,
             learning_rate=1e-3, debug=False):
    """Minimize a combined content + style + total-variation loss.

    This is a generator: nothing runs until it is iterated. On every
    reporting step it yields the current predictions and loss values.

    Args:
        content_targets: Sliceable sequence of items handed to ``get_img``
            (presumably image paths - confirm against the caller). Trailing
            items that do not fill a whole batch are dropped.
        style_target: Array-like style image; it is wrapped into a batch of
            one and pushed through the VGG network to get the style Gram
            matrices.
        content_weight: Multiplier applied to the content loss.
        style_weight: Multiplier applied to the style loss.
        tv_weight: Multiplier applied to the total-variation loss.
        vgg_path: Passed unchanged to ``vgg.net``.
        epochs: Number of passes over ``content_targets``.
        print_iterations: Report every this many iterations (counted per
            epoch); in ``slow`` mode, every this many epochs instead.
        batch_size: Images per training step; forced to 1 when ``slow``.
        save_path: Checkpoint path written on each reporting step when not
            in ``slow`` mode.
        slow: When true, optimize a free ``tf.Variable`` image directly
            instead of the output of ``transform.net``; no checkpoint is
            written and the yielded predictions go through
            ``vgg.unprocess``.
        learning_rate: Learning rate of the Adam optimizer.
        debug: When true, print the wall-clock time of every batch.

    Yields:
        ``(preds, losses, iterations, epoch)`` where ``losses`` is
        ``(style_loss, content_loss, tv_loss, total_loss)`` and
        ``iterations`` is the batch count within the current epoch.
    """
    if slow:
        batch_size = 1
    mod = len(content_targets) % batch_size
    if mod > 0:
        print("训练集的数据略有减少..取整")
        content_targets = content_targets[:-mod]

    style_features = {}
    batch_shape = (batch_size, 256, 256, 3)

    # Take the placeholder shape from the array that is actually fed to it.
    # The previous `style_target.style_shape` lookup fails on a plain
    # ndarray, which has no such attribute.
    style_pre = np.array([style_target])
    style_shape = style_pre.shape

    # Precompute the style Gram matrices in a throw-away graph/session.
    with tf.Graph().as_default(), tf.device('/gpu:0'), tf.Session() as sess:
        style_image = tf.placeholder(tf.float32, shape=style_shape,
                                     name='style_image')
        style_image_pre = vgg.preprocess(style_image)
        net = vgg.net(vgg_path, style_image_pre)
        for layer in STYLE_LAYERS:
            features = net[layer].eval(feed_dict={style_image: style_pre})
            # Flatten every axis but the last one into rows.
            features = np.reshape(features, (-1, features.shape[3]))
            gram = np.matmul(features.T, features) / features.size
            style_features[layer] = gram

    with tf.Graph().as_default(), tf.Session() as sess:
        X_content = tf.placeholder(tf.float32, shape=batch_shape,
                                   name="X_content")
        X_pre = vgg.preprocess(X_content)

        # Content features of the input batch (the content-loss target).
        content_features = {}
        content_net = vgg.net(vgg_path, X_pre)
        content_features[CONTENT_LAYER] = content_net[CONTENT_LAYER]

        if slow:
            # Optimize the image itself, starting from random noise.
            preds = tf.Variable(
                tf.random_normal(X_content.get_shape()) * 0.256
            )
            preds_pre = preds
        else:
            # Image produced by the transformation network (residual network,
            # per the original author's comment).
            preds = transform.net(X_content / 255.0)
            preds_pre = vgg.preprocess(preds)

        net = vgg.net(vgg_path, preds_pre)

        # Content loss: squared difference at CONTENT_LAYER, normalized by
        # the number of compared elements.
        content_size = _tensor_size(content_features[CONTENT_LAYER]) * batch_size
        assert _tensor_size(content_features[CONTENT_LAYER]) == _tensor_size(net[CONTENT_LAYER])
        content_loss = content_weight * (2 * tf.nn.l2_loss(
            net[CONTENT_LAYER] - content_features[CONTENT_LAYER]) / content_size
        )

        # Style loss: squared difference between per-image Gram matrices and
        # the precomputed style Gram matrices, summed over STYLE_LAYERS.
        style_losses = []
        for style_layer in STYLE_LAYERS:
            layer = net[style_layer]
            bs, height, width, filters = (dim.value for dim in layer.get_shape())
            size = height * width * filters
            feats = tf.reshape(layer, (bs, height * width, filters))
            feats_T = tf.transpose(feats, perm=[0, 2, 1])
            grams = tf.matmul(feats_T, feats) / size
            style_gram = style_features[style_layer]
            style_losses.append(
                2 * tf.nn.l2_loss(grams - style_gram) / style_gram.size
            )
        style_loss = style_weight * functools.reduce(tf.add, style_losses) / batch_size

        # Total-variation loss: squared differences between neighbouring
        # entries along axis 1 and axis 2 of the predictions.
        tv_y_size = _tensor_size(preds[:, 1:, :, :])
        tv_x_size = _tensor_size(preds[:, :, 1:, :])
        y_tv = tf.nn.l2_loss(preds[:, 1:, :, :] - preds[:, :batch_shape[1] - 1, :, :])
        x_tv = tf.nn.l2_loss(preds[:, :, 1:, :] - preds[:, :, :batch_shape[2] - 1, :])
        tv_loss = tv_weight * 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / batch_size

        # Overall loss.
        loss = content_loss + style_loss + tv_loss

        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

        # Build the Saver once, after the optimizer has created its
        # variables. Constructing it on every reporting step would add new
        # save/restore ops to the graph each time.
        saver = None if slow else tf.train.Saver()

        sess.run(tf.global_variables_initializer())

        import random
        uid = random.randint(1, 100)
        print("UID: %s" % uid)

        num_examples = len(content_targets)
        to_get = [style_loss, content_loss, tv_loss, loss, preds]

        for epoch in range(epochs):
            iterations = 0
            while iterations * batch_size < num_examples:
                start_time = time.time()
                curr = iterations * batch_size
                step = curr + batch_size
                X_batch = np.zeros(batch_shape, dtype=np.float32)
                for j, img_p in enumerate(content_targets[curr:step]):
                    X_batch[j] = get_img(img_p, batch_shape[1:]).astype(np.float32)

                iterations += 1
                assert X_batch.shape[0] == batch_size

                feed_dict = {X_content: X_batch}
                train_step.run(feed_dict=feed_dict)

                delta_time = time.time() - start_time
                if debug:
                    print("UID: %s, batch time: %s" % (uid, delta_time))

                if slow:
                    # In slow mode the reporting period is counted in epochs.
                    is_print_iter = epoch % print_iterations == 0
                else:
                    is_print_iter = int(iterations) % print_iterations == 0
                is_last = (epoch == epochs - 1
                           and iterations * batch_size >= num_examples)

                if is_print_iter or is_last:
                    # Re-evaluate losses and predictions on the batch that
                    # was just trained on.
                    tup = sess.run(to_get, feed_dict=feed_dict)
                    _style_loss, _content_loss, _tv_loss, _loss, _preds = tup
                    losses = (_style_loss, _content_loss, _tv_loss, _loss)
                    if slow:
                        _preds = vgg.unprocess(_preds)
                    else:
                        saver.save(sess, save_path)
                    yield (_preds, losses, iterations, epoch)
def _tensor_size(tensor):
    """Return the product of every dimension of ``tensor`` except the first.

    ``tensor.get_shape()`` must return a sliceable sequence whose items
    expose the dimension length as ``.value``. A shape with no dimensions
    after the first one yields 1.
    """
    element_count = 1
    # Skip the leading dimension and multiply the remaining ones together.
    for dim in tensor.get_shape()[1:]:
        element_count = element_count * dim.value
    return element_count