Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Freeze graph #1

Merged
merged 4 commits into from
May 18, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,5 @@ ENV/

logs/
Model_zoo/
res/
data/
res/
36 changes: 36 additions & 0 deletions FCN_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from PIL import Image
from six.moves import xrange
from scipy import misc
from tensorflow.python.framework import graph_util

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer("batch_size", "5", "batch size for training")
Expand Down Expand Up @@ -235,6 +236,40 @@ def main(argv=None):
utils.save_image(pred[itr].astype(np.uint8), FLAGS.logs_dir, name="pred_" + str(5+itr))
print("Saved image: %d" % itr)'''

def freeze_graph():
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
output_node_names = "save/restore_all"
print ("graph node names", output_node_names)
keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
image = tf.placeholder(tf.float32, shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 6], name="input_image")
annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 1], name="annotation")

pred_annotation, logits = inference(image, keep_probability)
sft = tf.nn.softmax(logits)
test_dataset_reader = TestDataset('data/testlist.mat')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
input_checkpoint = ckpt.model_checkpoint_path
absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
output_graph = absolute_model_folder + "/frozen_model.pb"
saver = tf.train.Saver()
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
print("Model restored...")
# We use a built-in TF helper to export variables to constants
output_graph_def = graph_util.convert_variables_to_constants(
sess,# The session is used to retrieve the weights
tf.get_default_graph().as_graph_def(),# The graph_def is used to retrieve the nodes
[output_node_names]# The output node names are used to select the usefull nodes
)
# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
print("saving graph to file successful")

def pred():
keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
image = tf.placeholder(tf.float32, shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 6], name="input_image")
Expand Down Expand Up @@ -296,3 +331,4 @@ def save_alpha_img(org, mat, name):
if __name__ == "__main__":
#tf.app.run()
pred()
#freeze_graph()