-
Notifications
You must be signed in to change notification settings - Fork 0
/
face_decoder.py
345 lines (279 loc) · 13.4 KB
/
face_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import tensorflow as tf
import math as m
import numpy as np
from scipy.io import loadmat
import platform
# tf_mesh_renderer is only imported on non-Windows hosts. Face3D.Render_block
# reads this flag: on Windows it refuses to train and skips rendering.
is_windows = "Windows" == platform.system()
if not is_windows:
    from renderer import mesh_renderer
###############################################################################################
# Reconstruct 3D face based on output coefficients and facemodel
###############################################################################################
# BFM 3D face model
class BFM():
    """Basel Face Model arrays loaded from a MATLAB .mat file as TF constants.

    Continuous bases (mean shape/texture and the identity, expression and
    texture bases) are stored as float32. Vertex/face index buffers are kept
    as stored and, per the per-field comments, are 1-based: subtract 1 before
    using them with tf.gather.
    """
    def __init__(self,model_path = './BFM/BFM_model_front.mat'):
        """Load the face model.

        Args:
            model_path: path to the .mat file. It must contain the keys
                meanshape, idBase, exBase, meantex, texBase, point_buf, tri,
                frontmask2_idx, tri_mask2, skinmask and keypoints.

        Raises:
            KeyError: if one of those keys is missing from the file.
        """
        model = loadmat(model_path)
        # These bases are combined with float32 tensors (einsum / add) in
        # Face3D and TF ops do not promote mixed dtypes, so cast all of them,
        # not just exBase. The cast is a no-op when the file is already float32.
        self.meanshape = tf.constant(model['meanshape'].astype(np.float32)) # mean face shape. [3*N,1]
        self.idBase = tf.constant(model['idBase'].astype(np.float32)) # identity basis. [3*N,80]
        self.exBase = tf.constant(model['exBase'].astype(np.float32)) # expression basis. [3*N,64]
        self.meantex = tf.constant(model['meantex'].astype(np.float32)) # mean face texture. [3*N,1] (0-255)
        self.texBase = tf.constant(model['texBase'].astype(np.float32)) # texture basis. [3*N,80]
        # Index buffers: dtype left as stored (callers cast to int32).
        self.point_buf = tf.constant(model['point_buf']) # face indices for each vertex that lies in. starts from 1. [N,8]
        self.face_buf = tf.constant(model['tri']) # vertex indices for each face. starts from 1. [F,3]
        self.front_mask_render = tf.squeeze(tf.constant(model['frontmask2_idx'])) # vertex indices for small face region to compute photometric error. starts from 1.
        self.mask_face_buf = tf.constant(model['tri_mask2']) # vertex indices for each face from small face region. starts from 1. [f,3]
        self.skin_mask = tf.squeeze(tf.constant(model['skinmask'])) # vertex indices for pre-defined skin region to compute reflectance loss
        self.keypoints = tf.squeeze(tf.constant(model['keypoints'])) # vertex indices for 68 landmarks. starts from 1. [68] after squeeze
# Analytic 3D face
class Face3D():
    """Analytic 3D face built from R-Net coefficients and a BFM face model.

    Reconstruction_Block builds the TensorFlow graph that turns a
    [batchsize,257] coefficient tensor into the following:
    - a face shape and texture;
    - lit vertex colors;
    - 2D landmark projections;
    - (off Windows) rendered images and masks.
    The intermediate tensors are stored as attributes on the instance.
    """

    # Pinhole camera shared by Projection_block and Render_block. Both must
    # use the same values or the projected landmarks and the rendered images
    # stop lining up, so they are defined once here.
    FOCAL = 1015.0      # focal length in pixels, multiplied by f_scale
    IMG_CENTER = 112.0  # principal point (same for x and y), half of IMG_SIZE
    IMG_SIZE = 224      # rendered image width and height in pixels
    CAMERA_DIST = 10.0  # camera z position, multiplied by camera_scale

    def __init__(self):
        """Load the BFM face model from its default path."""
        facemodel = BFM()
        self.facemodel = facemodel

    # analytic 3D face reconstructions with coefficients from R-Net
    def Reconstruction_Block(self,coeff,opt):
        """Build the full reconstruction graph for ``coeff``.

        Args:
            coeff: [batchsize,257] reconstruction coefficients.
            opt: options object. ``opt.batch_size`` and ``opt.is_train`` are
                forwarded to Render_block.

        Results are stored on ``self`` rather than returned. The stored
        tensors are id_coeff, ex_coeff, tex_coeff, f_scale, gamma,
        face_shape, face_shape_t, face_texture, face_color, landmark_p,
        render_imgs, img_mask and img_mask_crop.
        """
        id_coeff,ex_coeff,tex_coeff,angles,translation,gamma,camera_scale,f_scale = self.Split_coeff(coeff)
        # [batchsize,N,3] canonical face shape in BFM space
        face_shape = self.Shape_formation_block(id_coeff,ex_coeff,self.facemodel)
        # [batchsize,N,3] vertex texture (in RGB order)
        face_texture = self.Texture_formation_block(tex_coeff,self.facemodel)
        # [batchsize,3,3] rotation matrix for face shape
        rotation = self.Compute_rotation_matrix(angles)
        # [batchsize,N,3] vertex normal, rotated like the shape
        face_norm = self.Compute_norm(face_shape,self.facemodel)
        norm_r = tf.matmul(face_norm,rotation)
        # do rigid transformation for face shape using predicted rotation and translation
        face_shape_t = self.Rigid_transform_block(face_shape,rotation,translation)
        # compute 2d landmark projections
        # landmark_p: [batchsize,68,2]
        face_landmark_t = self.Compute_landmark(face_shape_t,self.facemodel)
        landmark_p = self.Projection_block(face_landmark_t,camera_scale,f_scale)
        # [batchsize,N,3] vertex color (in RGB order)
        face_color = self.Illumination_block(face_texture, norm_r, gamma)
        # reconstruction images and region masks for computing photometric loss
        render_imgs,img_mask,img_mask_crop = self.Render_block(face_shape_t,norm_r,face_color,camera_scale,f_scale,self.facemodel,opt.batch_size,opt.is_train)
        self.id_coeff = id_coeff
        self.ex_coeff = ex_coeff
        self.tex_coeff = tex_coeff
        self.f_scale = f_scale
        self.gamma = gamma
        self.face_shape = face_shape
        self.face_shape_t = face_shape_t
        self.face_texture = face_texture
        self.face_color = face_color
        self.landmark_p = landmark_p
        self.render_imgs = render_imgs
        self.img_mask = img_mask
        self.img_mask_crop = img_mask_crop

    #----------------------------------------------------------------------------------------------
    def Split_coeff(self,coeff):
        """Slice a [batchsize,257] coefficient tensor into its parts.

        Layout: identity 80 | expression 64 | texture 80 | angles 3 |
        gamma (SH lighting) 27 | translation 3. camera_scale and f_scale are
        not predicted and are fixed to ones of shape [batchsize,1].

        Returns:
            (id_coeff, ex_coeff, tex_coeff, angles, translation, gamma,
            camera_scale, f_scale)
        """
        id_coeff = coeff[:,:80]
        ex_coeff = coeff[:,80:144]
        tex_coeff = coeff[:,144:224]
        angles = coeff[:,224:227]
        gamma = coeff[:,227:254]
        translation = coeff[:,254:257]
        camera_scale = tf.ones([tf.shape(coeff)[0],1])
        f_scale = tf.ones([tf.shape(coeff)[0],1])
        return id_coeff,ex_coeff,tex_coeff,angles,translation,gamma,camera_scale,f_scale

    def Shape_formation_block(self,id_coeff,ex_coeff,facemodel):
        """Return the [batchsize,N,3] face shape, re-centered on the mean shape's centroid."""
        face_shape = tf.einsum('ij,aj->ai',facemodel.idBase,id_coeff) + \
            tf.einsum('ij,aj->ai',facemodel.exBase,ex_coeff) + facemodel.meanshape
        # reshape face shape to [batchsize,N,3]
        face_shape = tf.reshape(face_shape,[tf.shape(face_shape)[0],-1,3])
        # re-centering the face shape with mean shape
        face_shape = face_shape - tf.reshape(tf.reduce_mean(tf.reshape(facemodel.meanshape,[-1,3]),0),[1,1,3])
        return face_shape

    def Compute_norm(self,face_shape,facemodel):
        """Return unit per-vertex normals [batchsize,N,3] for ``face_shape``.

        Each vertex normal is the normalized sum of the unit normals of the
        faces listed for it in point_buf.
        """
        shape = face_shape
        # face_id and point_id index starts from 1
        face_id = tf.cast(facemodel.face_buf - 1,tf.int32)
        point_id = tf.cast(facemodel.point_buf - 1,tf.int32)
        # compute normal for each face
        v1 = tf.gather(shape,face_id[:,0], axis = 1)
        v2 = tf.gather(shape,face_id[:,1], axis = 1)
        v3 = tf.gather(shape,face_id[:,2], axis = 1)
        e1 = v1 - v2
        e2 = v2 - v3
        face_norm = tf.cross(e1,e2)
        face_norm = tf.nn.l2_normalize(face_norm, axis = 2) # normalized face_norm first
        # Append one all-zero face normal at index F. point_buf entries that
        # point there then contribute nothing to the vertex sum.
        face_norm = tf.concat([face_norm,tf.zeros([tf.shape(face_shape)[0],1,3])], axis = 1)
        # compute normal for each vertex using one-ring neighborhood
        v_norm = tf.reduce_sum(tf.gather(face_norm, point_id, axis = 1), axis = 2)
        v_norm = tf.nn.l2_normalize(v_norm, axis = 2)
        return v_norm

    def Texture_formation_block(self,tex_coeff,facemodel):
        """Return the [batchsize,N,3] vertex texture (RGB order) from texture coefficients."""
        face_texture = tf.einsum('ij,aj->ai',facemodel.texBase,tex_coeff) + facemodel.meantex
        # reshape face texture to [batchsize,N,3], note that texture is in RGB order
        face_texture = tf.reshape(face_texture,[tf.shape(face_texture)[0],-1,3])
        return face_texture

    def Compute_rotation_matrix(self,angles):
        """Build [batchsize,3,3] rotations from [batchsize,3] Euler angles (x, y, z).

        Returns the transpose of Rz*Ry*Rx. Points stored as row vectors can
        then be rotated with ``tf.matmul(points, rotation)``.
        """
        n_data = tf.shape(angles)[0]
        # compute rotation matrix for X-axis, Y-axis, Z-axis respectively
        rotation_X = tf.concat([tf.ones([n_data,1]),
                                tf.zeros([n_data,3]),
                                tf.reshape(tf.cos(angles[:,0]),[n_data,1]),
                                -tf.reshape(tf.sin(angles[:,0]),[n_data,1]),
                                tf.zeros([n_data,1]),
                                tf.reshape(tf.sin(angles[:,0]),[n_data,1]),
                                tf.reshape(tf.cos(angles[:,0]),[n_data,1])],
                               axis = 1
                               )
        rotation_Y = tf.concat([tf.reshape(tf.cos(angles[:,1]),[n_data,1]),
                                tf.zeros([n_data,1]),
                                tf.reshape(tf.sin(angles[:,1]),[n_data,1]),
                                tf.zeros([n_data,1]),
                                tf.ones([n_data,1]),
                                tf.zeros([n_data,1]),
                                -tf.reshape(tf.sin(angles[:,1]),[n_data,1]),
                                tf.zeros([n_data,1]),
                                tf.reshape(tf.cos(angles[:,1]),[n_data,1])],
                               axis = 1
                               )
        rotation_Z = tf.concat([tf.reshape(tf.cos(angles[:,2]),[n_data,1]),
                                -tf.reshape(tf.sin(angles[:,2]),[n_data,1]),
                                tf.zeros([n_data,1]),
                                tf.reshape(tf.sin(angles[:,2]),[n_data,1]),
                                tf.reshape(tf.cos(angles[:,2]),[n_data,1]),
                                tf.zeros([n_data,3]),
                                tf.ones([n_data,1])],
                               axis = 1
                               )
        rotation_X = tf.reshape(rotation_X,[n_data,3,3])
        rotation_Y = tf.reshape(rotation_Y,[n_data,3,3])
        rotation_Z = tf.reshape(rotation_Z,[n_data,3,3])
        # R = RzRyRx
        rotation = tf.matmul(tf.matmul(rotation_Z,rotation_Y),rotation_X)
        rotation = tf.transpose(rotation, perm = [0,2,1])
        return rotation

    def Projection_block(self,face_shape,camera_scale,f_scale):
        """Perspective-project [batchsize,N,3] points to [batchsize,N,2] pixel coordinates.

        The camera sits at (0, 0, CAMERA_DIST*camera_scale) and looks down -z.
        Its intrinsics are focal FOCAL*f_scale and principal point
        (IMG_CENTER, IMG_CENTER).
        """
        # pre-defined camera focal for perspective projection
        focal = tf.reshape(tf.constant(self.FOCAL)*f_scale,[-1,1])
        batchsize = tf.shape(focal)[0]
        # define camera position
        camera_pos = tf.reshape(tf.constant([0.0,0.0,self.CAMERA_DIST]),[1,1,3])*tf.reshape(camera_scale,[-1,1,1])
        reverse_z = tf.tile(tf.reshape(tf.constant([1.0,0,0,0,1,0,0,0,-1.0]),[1,3,3]),[tf.shape(face_shape)[0],1,1])
        # projection (intrinsics) matrix [[f,0,c],[0,f,c],[0,0,1]] per batch item
        zero = tf.zeros([batchsize,1])
        center = self.IMG_CENTER*tf.ones([batchsize,1])
        p_matrix = tf.concat([focal,zero,center,
                              zero,focal,center,
                              zero,zero,tf.ones([batchsize,1])],axis = 1)
        p_matrix = tf.reshape(p_matrix,[-1,3,3])
        # convert z in world space to the distance to camera
        face_shape = tf.matmul(face_shape,reverse_z) + camera_pos
        aug_projection = tf.matmul(face_shape,tf.transpose(p_matrix,[0,2,1]))
        # [batchsize, N,2] 2d face projection: divide x,y by depth
        face_projection = aug_projection[:,:,0:2]/aug_projection[:,:,2:3]
        return face_projection

    def Compute_landmark(self,face_shape,facemodel):
        """Gather the 68 landmark vertices of ``face_shape`` -> [batchsize,68,3]."""
        # keypoints are 1-based vertex indices
        keypoints_idx = tf.cast(facemodel.keypoints - 1,tf.int32)
        face_landmark = tf.gather(face_shape,keypoints_idx,axis = 1)
        return face_landmark

    def Illumination_block(self,face_texture,norm_r,gamma):
        """Shade vertex textures with 9-band spherical-harmonics lighting.

        Args:
            face_texture: [batchsize,N,3] vertex texture (RGB order).
            norm_r: [batchsize,N,3] rotated vertex normals.
            gamma: [batchsize,27] SH coefficients, 9 per RGB channel. An
                ambient term of 0.8 is added to the first coefficient of each
                channel.

        Returns:
            [batchsize,N,3] vertex color in RGB order.
        """
        n_data = tf.shape(gamma)[0]
        n_point = tf.shape(norm_r)[1]
        gamma = tf.reshape(gamma,[n_data,3,9])
        # set initial lighting with an ambient lighting
        init_lit = tf.constant([0.8,0,0,0,0,0,0,0,0])
        gamma = gamma + tf.reshape(init_lit,[1,1,9])
        # compute vertex color using SH function approximation
        a0 = m.pi
        a1 = 2*m.pi/tf.sqrt(3.0)
        a2 = 2*m.pi/tf.sqrt(8.0)
        c0 = 1/tf.sqrt(4*m.pi)
        c1 = tf.sqrt(3.0)/tf.sqrt(4*m.pi)
        c2 = 3*tf.sqrt(5.0)/tf.sqrt(12*m.pi)
        Y = tf.concat([tf.tile(tf.reshape(a0*c0,[1,1,1]),[n_data,n_point,1]),
                       tf.expand_dims(-a1*c1*norm_r[:,:,1],2),
                       tf.expand_dims(a1*c1*norm_r[:,:,2],2),
                       tf.expand_dims(-a1*c1*norm_r[:,:,0],2),
                       tf.expand_dims(a2*c2*norm_r[:,:,0]*norm_r[:,:,1],2),
                       tf.expand_dims(-a2*c2*norm_r[:,:,1]*norm_r[:,:,2],2),
                       tf.expand_dims(a2*c2*0.5/tf.sqrt(3.0)*(3*tf.square(norm_r[:,:,2])-1),2),
                       tf.expand_dims(-a2*c2*norm_r[:,:,0]*norm_r[:,:,2],2),
                       tf.expand_dims(a2*c2*0.5*(tf.square(norm_r[:,:,0])-tf.square(norm_r[:,:,1])),2)],axis = 2)
        color_r = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,0,:],2)),axis = 2)
        color_g = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,1,:],2)),axis = 2)
        color_b = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,2,:],2)),axis = 2)
        # [batchsize,N,3] vertex color in RGB order
        face_color = tf.stack([color_r*face_texture[:,:,0],color_g*face_texture[:,:,1],color_b*face_texture[:,:,2]],axis = 2)
        return face_color

    def Rigid_transform_block(self,face_shape,rotation,translation):
        """Rotate [batchsize,N,3] points by ``rotation`` and add [batchsize,3] ``translation``."""
        face_shape_r = tf.matmul(face_shape,rotation)
        face_shape_t = face_shape_r + tf.reshape(translation,[tf.shape(face_shape)[0],1,3])
        return face_shape_t

    def Render_block(self,face_shape,face_norm,face_color,camera_scale,f_scale,facemodel,batchsize,is_train=True):
        """Rasterize the lit face mesh with tf_mesh_renderer.

        Args:
            face_shape, face_norm, face_color: per-vertex tensors, reshaped
                here to [batchsize,N,3].
            camera_scale, f_scale: [batchsize,1] multipliers for the camera
                distance and focal length.
            facemodel: BFM instance providing the triangle and mask buffers.
            batchsize: batch size, used in the reshapes and tiles below.
            is_train: when True, also render the small front-face region to
                get its mask.

        Returns:
            The result depends on the mode:
            - is_train=True: (img, mask, mask_f). img is an
              [batchsize,224,224,3] float32 image in BGR order. mask is the
              full-face alpha and mask_f the small-region alpha, both
              [batchsize,224,224,1].
            - is_train=False: (img_rgba, mask, mask). img_rgba is the RGBA
              render clipped to [0,255], with channels still in RGB order
              (not flipped to BGR). The full-face alpha is returned twice.
            - On Windows: ([], [], []).

        Raises:
            ValueError: if training is requested on Windows.
        """
        if is_train and is_windows:
            raise ValueError('Not support training with Windows environment.')
        if is_windows:
            return [],[],[]
        # render reconstruction images
        n_vex = int(facemodel.idBase.shape[0]) // 3
        # vertical field of view (degrees) matching Projection_block's intrinsics
        fov_y = 2*tf.atan(self.IMG_CENTER/(self.FOCAL*f_scale))*180./m.pi
        fov_y = tf.reshape(fov_y,[batchsize])
        # full face region
        face_shape = tf.reshape(face_shape,[batchsize,n_vex,3])
        face_norm = tf.reshape(face_norm,[batchsize,n_vex,3])
        face_color = tf.reshape(face_color,[batchsize,n_vex,3])
        # Camera and lighting settings shared by both renders. Light
        # intensities are 0 because the vertex colors already include
        # lighting; ambient_color is white.
        render_kwargs = dict(
            camera_position = tf.constant([[0,0,self.CAMERA_DIST]])*tf.reshape(camera_scale,[-1,1]),
            camera_lookat = tf.constant([0,0,0.0]),
            camera_up = tf.constant([0,1.0,0]),
            light_positions = tf.tile(tf.reshape(tf.constant([0,0,1e5]),[1,1,3]),[batchsize,1,1]),
            light_intensities = tf.tile(tf.reshape(tf.constant([0.0,0.0,0.0]),[1,1,3]),[batchsize,1,1]),
            image_width = self.IMG_SIZE,
            image_height = self.IMG_SIZE,
            fov_y = fov_y,
            near_clip = 0.01,
            far_clip = 50.0,
            ambient_color = tf.tile(tf.reshape(tf.constant([1.0,1,1]),[1,3]),[batchsize,1]))
        # using tf_mesh_renderer for rasterization (https://github.com/google/tf_mesh_renderer)
        # img: [batchsize,224,224,3] images in RGB order (0-255)
        # mask:[batchsize,224,224,1] transparency for img ({0,1} value)
        with tf.device('/cpu:0'):
            img_rgba = mesh_renderer.mesh_renderer(face_shape,
                                                   tf.cast(facemodel.face_buf-1,tf.int32),
                                                   face_norm,
                                                   face_color,
                                                   **render_kwargs)
        img = img_rgba[:,:,:,:3]
        mask = img_rgba[:,:,:,3:]
        img = tf.cast(img[:,:,:,::-1],tf.float32) # transfer RGB to BGR
        mask = tf.cast(mask,tf.float32) # full face region
        if is_train:
            # pre-defined cropped face region (front_mask_render is 1-based)
            mask_idx = tf.cast(facemodel.front_mask_render-1,tf.int32)
            mask_face_shape = tf.gather(face_shape,mask_idx,axis = 1)
            mask_face_norm = tf.gather(face_norm,mask_idx,axis = 1)
            mask_face_color = tf.gather(face_color,mask_idx,axis = 1)
            # compute mask for small face region
            with tf.device('/cpu:0'):
                img_crop_rgba = mesh_renderer.mesh_renderer(mask_face_shape,
                                                            tf.cast(facemodel.mask_face_buf-1,tf.int32),
                                                            mask_face_norm,
                                                            mask_face_color,
                                                            **render_kwargs)
            mask_f = img_crop_rgba[:,:,:,3:]
            mask_f = tf.cast(mask_f,tf.float32) # small face region
            return img,mask,mask_f
        # test time: return the clipped RGBA render (RGB order, alpha kept)
        img_rgba = tf.cast(tf.clip_by_value(img_rgba,0,255),tf.float32)
        return img_rgba,mask,mask