Skip to content

Commit

Permalink
[tensorflow] Introduce default tensor names
Browse files Browse the repository at this point in the history
When the default tensor string name type, `DefaultTensorName`, is used, the
`TFlowMetaTensor.reify` will not assume that the string name is "strict" and
rely upon the base graph to assign a unique name derived from the default one.

Closes #93.
  • Loading branch information
brandonwillard committed Dec 4, 2019
1 parent b504958 commit a17a14d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 58 deletions.
106 changes: 53 additions & 53 deletions symbolic_pymc/tensorflow/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@
tf_metatize_cache = Cache(50)


class DefaultTensorName(str):
"""A type used to indicate a default tensor name."""

pass


class MetaOpDefLibrary(object):
"""A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct.
Expand Down Expand Up @@ -366,10 +372,16 @@ def _protobuf_convert(cls, k, v):
raise TypeError(f"Could not convert {k}")

def __init__(self, op, name, attr, obj=None):
"""Create a TF meta NodeDef.
XXX: Meta NodeDefs with `name == None` have a special meaning;
their names are uniquely generated. We still consider them equal
(when every other property is equal, of course).
"""
super().__init__(obj=obj)
self.op = metatize(op)
assert name is not None
self.name = name if isvar(name) else str(name)
self.name = name if isvar(name) else name

if not isvar(attr):
opdef_sig, _ = op_def_lib.get_op_info(self.op)
Expand Down Expand Up @@ -600,6 +612,11 @@ def reify(self):
# An operation with this name might already exist in the graph
#
try:
# FIXME: Lame hack
if isinstance(self.name, DefaultTensorName):
# Use a unique version of the default name.
raise KeyError()

existing_op = ops.get_default_graph().get_operation_by_name(self.name)
except KeyError:
#
Expand All @@ -613,7 +630,15 @@ def reify(self):
# An `Operation` with this name exists, let's make sure it's
# equivalent to this meta `Operation`
#
if self != mt(existing_op):
existing_op_mt = mt(existing_op)

# # Since we can't exactly reproduce all NodeDef.attr information
# # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr
# # fields from comparisons with same-named nodes in the graph.
# if op_attrs.keys() != node_attr.keys():
# existing_op_mt.node_def.attr = node_attr

if self != existing_op_mt:
raise MetaReificationError(
f"An Operation with the name {self.name}"
" already exists in the graph and is not"
Expand Down Expand Up @@ -725,40 +750,40 @@ def reify(self):

def __truediv__(self, y):
# TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
return mt.realdiv(self, y, name="truediv")
return mt.realdiv(self, y, name=DefaultTensorName("truediv"))

def __rtruediv__(self, x):
# TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
return mt.realdiv(x, self, name="truediv")
return mt.realdiv(x, self, name=DefaultTensorName("truediv"))

def __add__(self, y):
# TODO: If `self.dtype == tf.dtypes.string`, use `mt.add`
return mt.addv2(self, y, name="add")
return mt.addv2(self, y, name=DefaultTensorName("add"))

def __radd__(self, x):
# TODO: If `x.dtype == tf.dtypes.string`, use `mt.add`
return mt.addv2(x, self, name="add")
return mt.addv2(x, self, name=DefaultTensorName("add"))

def __sub__(self, y):
return mt.sub(self, y, name="sub")
return mt.sub(self, y, name=DefaultTensorName("sub"))

def __rsub__(self, x):
return mt.sub(x, self, name="sub")
return mt.sub(x, self, name=DefaultTensorName("sub"))

def __mul__(self, y):
return mt.mul(self, y, name="mul")
return mt.mul(self, y, name=DefaultTensorName("mul"))

def __rmul__(self, x):
return mt.mul(x, self, name="mul")
return mt.mul(x, self, name=DefaultTensorName("mul"))

def __abs__(self):
return mt.abs(self, name="Abs")
return mt.abs(self, name=DefaultTensorName("Abs"))

def __pow__(self, y):
return mt.pow(self, y, name="pow")
return mt.pow(self, y, name=DefaultTensorName("pow"))

def __neg__(self):
return mt.neg(self, name="Neg")
return mt.neg(self, name=DefaultTensorName("Neg"))


class TFlowMetaTensorShape(TFlowMetaSymbol):
Expand Down Expand Up @@ -987,48 +1012,22 @@ def __api_call__(self, *args, **kwargs):

if not op_args_unreified:

res_var = None
# name = op_args.get("name", None)
#
# if name is not None:
# #
# # An operation with this name might already exist in the graph
# #
#
# from tensorflow.python.framework import ops
# We create the `Operation` in the graph
#
# try:
# this_op = ops.get_default_graph().get_operation_by_name(name)
# except KeyError:
# pass
# else:
# # TODO: Make sure the existing `Operation` matches our arguments
# assert this_op.type == self.op_def.obj.name
#
# this_op = mt(this_op)
# op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args)
# assert op_inputs == this_op.inputs
# assert op_node_def == this_op.node_def
# res_var = this_op.default_output

if res_var is None:
#
# We create the `Operation` in the graph
#

tf_out = self._apply_func(**op_args)

# Ensure that the original meta objects will be available
# for use in the `metatize` that follows
tf_metatize_cache.update(
{
k: v
for k, v in zip(op_args.values(), apply_arguments.values())
if isinstance(k, tf.Tensor)
}
)
tf_out = self._apply_func(**op_args)

# Ensure that the original meta objects will be available
# for use in the `metatize` that follows
tf_metatize_cache.update(
{
k: v
for k, v in zip(op_args.values(), apply_arguments.values())
if isinstance(k, tf.Tensor)
}
)

res_var = metatize(tf_out)
res_var = metatize(tf_out)

if "names" in meta._lvar_defaults_enabled:
# This should also reset the NodeDef's `obj`
Expand Down Expand Up @@ -1073,7 +1072,8 @@ def op_args_to_operation_inputs(self, apply_arguments):
node_attr = var()

if "names" not in meta._lvar_defaults_enabled:
op_name = apply_arguments.get("name", op_def_tf.name) or op_def_tf.name
default_name = DefaultTensorName(op_def_tf.name)
op_name = apply_arguments.get("name", default_name) or default_name
else:
op_name = var()

Expand Down
19 changes: 17 additions & 2 deletions tests/tensorflow/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TFlowMetaOperator,
MetaOpDefLibrary,
MetaReificationError,
DefaultTensorName,
mt)

from tests.tensorflow import run_in_graph_mode
Expand Down Expand Up @@ -636,7 +637,7 @@ def test_global_options():
with tf.Graph().as_default(), disable_auto_reification():
y_mt = mt.Placeholder('float')
assert y_mt.obj is None
assert y_mt.name == 'Placeholder:0'
assert isinstance(y_mt.op.name, DefaultTensorName)
assert isinstance(y_mt.op.node_def.attr, dict)

with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'):
Expand Down Expand Up @@ -706,7 +707,7 @@ def test_meta_const():
@run_in_graph_mode
def test_meta_existing_names():

with tf.Graph().as_default():
with tf.Graph().as_default() as test_graph:
one_mt = mt(1)
assert one_mt.op.name == 'Const'

Expand All @@ -723,6 +724,7 @@ def test_meta_existing_names():
# Make sure it's the first base variable we created
assert orig_one_tf is one_tf

# FYI: This implicitly creates 'Const_1'
two_mt = mt(2)
two_mt.op.node_def.name = 'Const'

Expand All @@ -736,3 +738,16 @@ def test_meta_existing_names():

with pytest.raises(MetaReificationError):
two_mt.reify()

another_one_mt = TFlowMetaOperator('Const', None)(3, var())
# The following is something that would happen as a result of
# reification (of the lvar in the meta object, not the meta object
# itself).
another_one_mt.op.node_def.attr['dtype'] = tf.int32

assert another_one_mt.op.name == 'Const'
assert isinstance(another_one_mt.op.name, DefaultTensorName)
# We need to make sure that the reified meta object actually uses a
# unique name.
assert isinstance(another_one_mt.reify(), tf.Tensor)
assert another_one_mt.reify().op.name == 'Const_2'
33 changes: 30 additions & 3 deletions tests/tensorflow/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def test_etuple_term():
# TODO FIXME: Because of the above two, this errs
# add_lvar_et = etuplize(add_lvar_mt)


@run_in_graph_mode
def test_basic_unify_reify():
# Test reification with manually constructed replacements
Expand All @@ -127,8 +128,11 @@ def test_basic_unify_reify():

test_expr = mt.add(tf.constant(1, dtype=tf.float64),
mt.mul(tf.constant(2, dtype=tf.float64),
x_l))
test_reify_res = reify(test_expr, {x_l: a})
x_l, name=var('mul_name')),
name=var('add_name'))
test_reify_res = reify(test_expr, {x_l: a,
var('add_name'): 'Add_10',
var('mul_name'): 'Mul_10'})
test_base_res = test_reify_res.reify()
assert isinstance(test_base_res, tf.Tensor)

Expand All @@ -141,7 +145,7 @@ def test_basic_unify_reify():
# Simply make sure that unification succeeds
meta_expected_res = mt(expected_res)
s_test = unify(test_expr, meta_expected_res, {})
assert len(s_test) == 3
assert len(s_test) == 5

assert reify(test_expr, s_test) == meta_expected_res

Expand Down Expand Up @@ -199,3 +203,26 @@ def test_sexp_unify_reify():
# Now, the second, `A . y`
assert z_dist_tf.op.inputs[1].op.inputs[0] == A
assert z_dist_tf.op.inputs[1].op.inputs[1] == y


@run_in_graph_mode
@pytest.mark.xfail(strict=True)
def test_unique_names():

first_div_mt = mt(1) / mt(2)

assert first_div_mt.op.name == 'truediv'
assert first_div_mt.reify().op.name

div_lv = mt.realdiv(var('b'), var('c'), name=var('name'))
# Unify with the TF graph, then reify
s = unify(first_div_mt.reify(), div_lv)

s[var('b')] = 1
s[var('b')] = 3

div_mt = reify(div_lv, s)

assert div_mt.op.name == 'truediv'
assert isinstance(div_mt.reify(), tf.Tensor)
assert first_div_mt.reify() != div_mt.reify()

0 comments on commit a17a14d

Please sign in to comment.