-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_pb.py~
executable file
·52 lines (43 loc) · 1.95 KB
/
gen_pb.py~
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from google.protobuf import text_format
import os
import numpy as np
from PIL import Image
import importlib
import utli
# Freeze a trained TF1 checkpoint into a single GraphDef (.pb) and write a
# human-readable text dump (.pbtxt) of the inference graph next to it.

# Checkpoint path handed to `saver.restore`.
ckpt_path = "/home/data02_disk/tao/3DFace_server/log/faceFlat3/20210713_015922/-154440"
# Destination of the frozen binary graph.
pb_path = "/home/data02_disk/tao/3DFace_server/log/faceFlat3/pb/spc_2d_0713.pb"
# Dotted module path of the network definition; it must expose `inference`.
model_def = "models.faceFlat3"
# Hide every GPU so the export runs on CPU. This has to be set before the
# session below is created.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Height and width of the input placeholder (single channel).
image_h_size = 112
image_w_size = 96

with tf.Session() as sess:
    # Swap only the trailing extension. A plain str.replace would rewrite every
    # ".pb" substring in the path and, when the path has no ".pb" at all,
    # would return pb_path itself so the text dump would clobber the binary.
    output_graph_txt = os.path.splitext(pb_path)[0] + ".pbtxt"

    # Build the inference graph in test mode on a batch-size-agnostic input.
    img_placeholder = tf.placeholder(
        tf.float32, [None, image_h_size, image_w_size, 1])
    network = importlib.import_module(model_def)
    Predictions = network.inference(
        img_placeholder, phase_train=False, mode="test")

    # Snapshot the graph before the Saver is constructed; this snapshot is
    # what gets frozen and dumped below.
    graph_def = tf.get_default_graph().as_graph_def()
    # Debug dump of the graph, e.g. to look up node names.
    print(graph_def)

    # NOTE(review): presumably selects the variables that exist in the
    # checkpoint -- confirm against utli.get_vars_to_restore.
    vars_to_restore = utli.get_vars_to_restore(ckpt_path)
    saver = tf.train.Saver(vars_to_restore)
    saver.restore(sess, ckpt_path)
    print('load')

    # Replace variables with constants holding the restored values, keeping
    # only what is needed to compute the listed output node.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, ['mobilenet/out'])

    # Make sure the output directory exists so the writes below cannot fail
    # on a missing folder after the model build and restore are done.
    output_dir = os.path.dirname(pb_path)
    if output_dir:
        tf.gfile.MakeDirs(output_dir)

    with tf.gfile.GFile(pb_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    # Text dump of the graph.
    # NOTE(review): this writes the un-frozen `graph_def`, not
    # `output_graph_def` -- confirm that is the intended content.
    with tf.gfile.GFile(output_graph_txt, "w") as f:
        f.write(text_format.MessageToString(graph_def))