-
Notifications
You must be signed in to change notification settings - Fork 7
/
wct.py
79 lines (57 loc) · 2.56 KB
/
wct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 19 14:55:51 2018
@author: zwx
"""
import tensorflow as tf
def _wct_flatten(feature):
    """Reorder an encoding to channel-major rows: (..., H, W, C) -> (C, H*W).

    Returns the flattened tensor together with the dynamic C, H and W scalars.
    """
    # Keep only the trailing HxWxC dims. A bare tf.squeeze() would also drop
    # H, W or C whenever one of them happens to equal 1, which breaks the
    # rank-3 transpose below. The reshape fails if the leading dims are not 1.
    hwc = tf.reshape(feature, tf.shape(feature)[-3:])
    chw = tf.transpose(hwc, (2, 0, 1))
    channels, height, width = tf.unstack(tf.shape(chw))
    flat = tf.reshape(chw, (channels, height * width))
    return flat, channels, height, width


def _wct_center_and_covariance(flat, channels, num_positions, eps):
    """Return (per-channel mean, centered features, regularised CxC covariance)."""
    mean = tf.reduce_mean(flat, axis=1, keep_dims=True)
    centered = flat - mean
    # Unbiased estimate (N - 1); clamp so a single-position map does not
    # divide by zero. For N >= 2 the value is unchanged.
    denom = tf.maximum(tf.cast(num_positions, tf.float32) - 1., 1.)
    cov = tf.matmul(centered, centered, transpose_b=True) / denom + tf.eye(channels) * eps
    return mean, centered, cov


def _wct_spectral_power(cov, power, threshold=1e-5):
    """Return U_k * diag(S_k ** power) * U_k^T using only singular values > threshold."""
    # tf.svd is slower on GPU, see https://github.com/tensorflow/tensorflow/issues/13603
    with tf.device('/cpu:0'):
        singular_values, left_vectors, _ = tf.svd(cov)
    # Count the singular values above the threshold; the leading-k slices
    # below rely on tf.svd returning them in descending order.
    k = tf.reduce_sum(tf.cast(tf.greater(singular_values, threshold), tf.int32))
    scale = tf.diag(tf.pow(singular_values[:k], power))
    basis = left_vectors[:, :k]
    return tf.matmul(tf.matmul(basis, scale), basis, transpose_b=True)


def wct_tf(content, style, alpha, eps=1e-8):
    """TensorFlow version of the Whiten-Color Transform (WCT).

    Assumes content/style encodings have shape 1xHxWxC (a plain HxWxC tensor
    is accepted as well). See p.4 of the Universal Style Transfer paper for
    the corresponding equations: https://arxiv.org/pdf/1705.08086.pdf

    Args:
        content: content feature encoding.
        style: style feature encoding; its spatial size may differ from
            the content's.
        alpha: blend weight; 1 yields the fully whiten-colored feature,
            0 returns the content feature unchanged.
        eps: value added to the diagonal of both covariance matrices.

    Returns:
        The blended feature with shape 1xHxWxC, where H, W and C are taken
        from the content encoding.
    """
    content_flat, Cc, Hc, Wc = _wct_flatten(content)
    style_flat, Cs, Hs, Ws = _wct_flatten(style)

    # Center both feature sets and estimate their channel covariances
    mc, fc, fcfc = _wct_center_and_covariance(content_flat, Cc, Hc * Wc, eps)
    ms, fs, fsfs = _wct_center_and_covariance(style_flat, Cs, Hs * Ws, eps)

    # Whiten content feature: apply covariance^(-1/2)
    fc_hat = tf.matmul(_wct_spectral_power(fcfc, -0.5), fc)

    # Color content with style: apply covariance^(+1/2), then re-center
    # with the mean of the style
    fcs_hat = tf.matmul(_wct_spectral_power(fsfs, 0.5), fc_hat) + ms

    # Blend whiten-colored feature with original content feature
    blended = alpha * fcs_hat + (1 - alpha) * content_flat

    # CxH*W -> CxHxW -> 1xHxWxC
    blended = tf.reshape(blended, (Cc, Hc, Wc))
    return tf.expand_dims(tf.transpose(blended, (1, 2, 0)), 0)
def Adain(content, style, eps=1e-8):
    """Adaptive Instance Normalization (AdaIN) of content features by style.

    Computes sigma(style) * (content - mu(content)) / sigma(content) + mu(style),
    with the statistics taken over axes 1 and 2 of each input
    (presumably the spatial axes of an NxHxWxC tensor -- confirm with callers).

    Args:
        content: content feature tensor.
        style: style feature tensor.
        eps: added to each variance before the square root to avoid
            dividing by zero.

    Returns:
        The content features re-normalised to the style's mean and standard
        deviation.
    """
    mean_c, var_c = tf.nn.moments(content, axes=[1, 2], keep_dims=True)
    mean_s, var_s = tf.nn.moments(style, axes=[1, 2], keep_dims=True)

    # AdaIN scales by standard deviations. The previous code divided by the
    # content variance and multiplied by the style variance.
    # NOTE(review): this changes numerical output; a decoder trained against
    # the old variance-based scaling may need retraining -- confirm.
    std_c = tf.sqrt(var_c + eps)
    std_s = tf.sqrt(var_s + eps)

    instance_normalization = (content - mean_c) / std_c
    stylized_feature = instance_normalization * std_s + mean_s
    return stylized_feature