From a17a14d5486a95cbfe18e7b2f1f3ce196bc9e0f1 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 3 Dec 2019 00:50:55 -0600 Subject: [PATCH] [tensorflow] Introduce default tensor names When the default tensor string name type, `DefaultTensorName`, is used, the `TFlowMetaTensor.reify` will not assume that the string name is "strict" and rely upon the base graph to assign a unique name derived from the default one. Closes #93. --- symbolic_pymc/tensorflow/meta.py | 106 +++++++++++++++---------------- tests/tensorflow/test_meta.py | 19 +++++- tests/tensorflow/test_unify.py | 33 +++++++++- 3 files changed, 100 insertions(+), 58 deletions(-) diff --git a/symbolic_pymc/tensorflow/meta.py b/symbolic_pymc/tensorflow/meta.py index 5daefe8..fb780ab 100644 --- a/symbolic_pymc/tensorflow/meta.py +++ b/symbolic_pymc/tensorflow/meta.py @@ -48,6 +48,12 @@ tf_metatize_cache = Cache(50) +class DefaultTensorName(str): + """A type used to indicate a default tensor name.""" + + pass + + class MetaOpDefLibrary(object): """A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct. @@ -366,10 +372,16 @@ def _protobuf_convert(cls, k, v): raise TypeError(f"Could not convert {k}") def __init__(self, op, name, attr, obj=None): + """Create a TF meta NodeDef. + + XXX: Meta NodeDefs with `name == None` have a special meaning; + their names are uniquely generated. We still consider them equal + (when every other property is equal, of course). + """ super().__init__(obj=obj) self.op = metatize(op) assert name is not None - self.name = name if isvar(name) else str(name) + self.name = name if isvar(name) else name if not isvar(attr): opdef_sig, _ = op_def_lib.get_op_info(self.op) @@ -600,6 +612,11 @@ def reify(self): # An operation with this name might already exist in the graph # try: + # FIXME: Lame hack + if isinstance(self.name, DefaultTensorName): + # Use a unique version of the default name. + raise KeyError() + existing_op = ops.get_default_graph().get_operation_by_name(self.name) except KeyError: # @@ -613,7 +630,15 @@ def reify(self): # An `Operation` with this name exists, let's make sure it's # equivalent to this meta `Operation` # - if self != mt(existing_op): + existing_op_mt = mt(existing_op) + + # # Since we can't exactly reproduce all NodeDef.attr information + # # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr + # # fields from comparisons with same-named nodes in the graph. + # if op_attrs.keys() != node_attr.keys(): + # existing_op_mt.node_def.attr = node_attr + + if self != existing_op_mt: raise MetaReificationError( f"An Operation with the name {self.name}" " already exists in the graph and is not" @@ -725,40 +750,40 @@ def reify(self): def __truediv__(self, y): # TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here. - return mt.realdiv(self, y, name="truediv") + return mt.realdiv(self, y, name=DefaultTensorName("truediv")) def __rtruediv__(self, x): # TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here. - return mt.realdiv(x, self, name="truediv") + return mt.realdiv(x, self, name=DefaultTensorName("truediv")) def __add__(self, y): # TODO: If `self.dtype == tf.dtypes.string`, use `mt.add` - return mt.addv2(self, y, name="add") + return mt.addv2(self, y, name=DefaultTensorName("add")) def __radd__(self, x): # TODO: If `x.dtype == tf.dtypes.string`, use `mt.add` - return mt.addv2(x, self, name="add") + return mt.addv2(x, self, name=DefaultTensorName("add")) def __sub__(self, y): - return mt.sub(self, y, name="sub") + return mt.sub(self, y, name=DefaultTensorName("sub")) def __rsub__(self, x): - return mt.sub(x, self, name="sub") + return mt.sub(x, self, name=DefaultTensorName("sub")) def __mul__(self, y): - return mt.mul(self, y, name="mul") + return mt.mul(self, y, name=DefaultTensorName("mul")) def __rmul__(self, x): - return mt.mul(x, self, name="mul") + return mt.mul(x, self, name=DefaultTensorName("mul")) def __abs__(self): - return mt.abs(self, name="Abs") + return mt.abs(self, name=DefaultTensorName("Abs")) def __pow__(self, y): - return mt.pow(self, y, name="pow") + return mt.pow(self, y, name=DefaultTensorName("pow")) def __neg__(self): - return mt.neg(self, name="Neg") + return mt.neg(self, name=DefaultTensorName("Neg")) class TFlowMetaTensorShape(TFlowMetaSymbol): @@ -987,48 +1012,22 @@ def __api_call__(self, *args, **kwargs): if not op_args_unreified: - res_var = None - # name = op_args.get("name", None) - # - # if name is not None: - # # - # # An operation with this name might already exist in the graph - # # # - # from tensorflow.python.framework import ops + # We create the `Operation` in the graph # - # try: - # this_op = ops.get_default_graph().get_operation_by_name(name) - # except KeyError: - # pass - # else: - # # TODO: Make sure the existing `Operation` matches our arguments - # assert this_op.type == self.op_def.obj.name - # - # this_op = mt(this_op) - # op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args) - # assert op_inputs == this_op.inputs - # assert op_node_def == this_op.node_def - # res_var = this_op.default_output - - if res_var is None: - # - # We create the `Operation` in the graph - # - - tf_out = self._apply_func(**op_args) - - # Ensure that the original meta objects will be available - # for use in the `metatize` that follows - tf_metatize_cache.update( - { - k: v - for k, v in zip(op_args.values(), apply_arguments.values()) - if isinstance(k, tf.Tensor) - } - ) + tf_out = self._apply_func(**op_args) + + # Ensure that the original meta objects will be available + # for use in the `metatize` that follows + tf_metatize_cache.update( + { + k: v + for k, v in zip(op_args.values(), apply_arguments.values()) + if isinstance(k, tf.Tensor) + } + ) - res_var = metatize(tf_out) + res_var = metatize(tf_out) if "names" in meta._lvar_defaults_enabled: # This should also reset the NodeDef's `obj` @@ -1073,7 +1072,8 @@ def op_args_to_operation_inputs(self, apply_arguments): node_attr = var() if "names" not in meta._lvar_defaults_enabled: - op_name = apply_arguments.get("name", op_def_tf.name) or op_def_tf.name + default_name = DefaultTensorName(op_def_tf.name) + op_name = apply_arguments.get("name", default_name) or default_name else: op_name = var() diff --git a/tests/tensorflow/test_meta.py b/tests/tensorflow/test_meta.py index 41a0d89..cfa9589 100644 --- a/tests/tensorflow/test_meta.py +++ b/tests/tensorflow/test_meta.py @@ -25,6 +25,7 @@ TFlowMetaOperator, MetaOpDefLibrary, MetaReificationError, + DefaultTensorName, mt) from tests.tensorflow import run_in_graph_mode @@ -636,7 +637,7 @@ def test_global_options(): with tf.Graph().as_default(), disable_auto_reification(): y_mt = mt.Placeholder('float') assert y_mt.obj is None - assert y_mt.name == 'Placeholder:0' + assert isinstance(y_mt.op.name, DefaultTensorName) assert isinstance(y_mt.op.node_def.attr, dict) with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'): @@ -706,7 +707,7 @@ def test_meta_const(): @run_in_graph_mode def test_meta_existing_names(): - with tf.Graph().as_default(): + with tf.Graph().as_default() as test_graph: one_mt = mt(1) assert one_mt.op.name == 'Const' @@ -723,6 +724,7 @@ def test_meta_existing_names(): # Make sure it's the first base variable we created assert orig_one_tf is one_tf + # FYI: This implicitly creates 'Const_1' two_mt = mt(2) two_mt.op.node_def.name = 'Const' @@ -736,3 +738,16 @@ def test_meta_existing_names(): with pytest.raises(MetaReificationError): two_mt.reify() + + another_one_mt = TFlowMetaOperator('Const', None)(3, var()) + # The following is something that would happen as a result of + # reification (of the lvar in the meta object, not the meta object + # itself). + another_one_mt.op.node_def.attr['dtype'] = tf.int32 + + assert another_one_mt.op.name == 'Const' + assert isinstance(another_one_mt.op.name, DefaultTensorName) + # We need to make sure that the reified meta object actually uses a + # unique name. + assert isinstance(another_one_mt.reify(), tf.Tensor) + assert another_one_mt.reify().op.name == 'Const_2' diff --git a/tests/tensorflow/test_unify.py b/tests/tensorflow/test_unify.py index f29ebd6..138b1a1 100644 --- a/tests/tensorflow/test_unify.py +++ b/tests/tensorflow/test_unify.py @@ -114,6 +114,7 @@ def test_etuple_term(): # TODO FIXME: Because of the above two, this errs # add_lvar_et = etuplize(add_lvar_mt) + @run_in_graph_mode def test_basic_unify_reify(): # Test reification with manually constructed replacements @@ -127,8 +128,11 @@ def test_basic_unify_reify(): test_expr = mt.add(tf.constant(1, dtype=tf.float64), mt.mul(tf.constant(2, dtype=tf.float64), - x_l)) - test_reify_res = reify(test_expr, {x_l: a}) + x_l, name=var('mul_name')), + name=var('add_name')) + test_reify_res = reify(test_expr, {x_l: a, + var('add_name'): 'Add_10', + var('mul_name'): 'Mul_10'}) test_base_res = test_reify_res.reify() assert isinstance(test_base_res, tf.Tensor) @@ -141,7 +145,7 @@ def test_basic_unify_reify(): # Simply make sure that unification succeeds meta_expected_res = mt(expected_res) s_test = unify(test_expr, meta_expected_res, {}) - assert len(s_test) == 3 + assert len(s_test) == 5 assert reify(test_expr, s_test) == meta_expected_res @@ -199,3 +203,26 @@ def test_sexp_unify_reify(): # Now, the second, `A . y` assert z_dist_tf.op.inputs[1].op.inputs[0] == A assert z_dist_tf.op.inputs[1].op.inputs[1] == y + + +@run_in_graph_mode +@pytest.mark.xfail(strict=True) +def test_unique_names(): + + first_div_mt = mt(1) / mt(2) + + assert first_div_mt.op.name == 'truediv' + assert first_div_mt.reify().op.name + + div_lv = mt.realdiv(var('b'), var('c'), name=var('name')) + # Unify with the TF graph, then reify + s = unify(first_div_mt.reify(), div_lv) + + s[var('b')] = 1 + s[var('b')] = 3 + + div_mt = reify(div_lv, s) + + assert div_mt.op.name == 'truediv' + assert isinstance(div_mt.reify(), tf.Tensor) + assert first_div_mt.reify() != div_mt.reify()