Skip to content

Commit

Permalink
[docs] Use the normalize_tf_graph now in symbolic_pymc
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 3, 2020
1 parent 968fc01 commit 10d6e2c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 310 deletions.
64 changes: 5 additions & 59 deletions docs/source/tensorflow-radon-example.org
Original file line number Diff line number Diff line change
Expand Up @@ -645,72 +645,18 @@ breadth of relevant operator coverage isn't clear; however, the normalizations
that it does provide are worth using, so we'll make use of them throughout.
:END:

[[grappler-normalize-function]] provides a simple means of
src_python[:eval never]{symbolic_pymc.tensorflow.graph.normalize_tf_graph} provides a simple means of
applying src_python[:eval never]{grappler}.

#+NAME: grappler-normalize-function
#+BEGIN_SRC python :exports code :results silent
from tensorflow.core.protobuf import config_pb2

from tensorflow.python.framework import ops
from tensorflow.python.framework import importer
from tensorflow.python.framework import meta_graph

from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import tf_optimizer


try:
gcluster = cluster.Cluster()
except tf.errors.UnavailableError:
pass

config = config_pb2.ConfigProto()


def normalize_tf_graph(graph_output, new_graph=True, verbose=False):
"""Use grappler to normalize a graph.

Arguments
=========
graph_output: Tensor
A tensor we want to consider as "output" of a FuncGraph.

Returns
=======
The simplified graph.
"""
train_op = graph_output.graph.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.clear()
train_op.extend([graph_output])

metagraph = meta_graph.create_meta_graph_def(graph=graph_output.graph)

optimized_graphdef = tf_optimizer.OptimizeGraph(
config, metagraph, verbose=verbose, cluster=gcluster)

output_name = graph_output.name

if new_graph:
optimized_graph = ops.Graph()
else:
optimized_graph = ops.get_default_graph()
del graph_output

with optimized_graph.as_default():
importer.import_graph_def(optimized_graphdef, name="")

opt_graph_output = optimized_graph.get_tensor_by_name(output_name)

return opt_graph_output
#+END_SRC

In [[grappler-normalize-function]] we
In [[grappler-normalize-test-graph]] we
run src_python[:eval never]{grappler} on the log-likelihood graph for a normal
random variable from [[tfp-normal-log-lik-graph]].

#+NAME: grappler-normalize-test-graph
#+BEGIN_SRC python :exports code :results silent :wrap
from symbolic_pymc.tensorflow.graph import normalize_tf_graph


normal_log_lik_opt = normalize_tf_graph(normal_log_lik)
#+END_SRC

Expand Down
Loading

0 comments on commit 10d6e2c

Please sign in to comment.