From b5cc693169fb24d70b6e45d2d0d374d47e1582fd Mon Sep 17 00:00:00 2001 From: Sandro Braun Date: Wed, 15 Apr 2020 16:38:13 +0200 Subject: [PATCH 1/8] Adds wrapper for torch_from_tensorflow --- README.md | 33 +++++++++++++++++++++++ tests/test_adapters.py | 54 ++++++++++++++++++++++++++++++++++++++ tfpyth/__init__.py | 59 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 145 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 03a40e4..d2685ae 100644 --- a/README.md +++ b/README.md @@ -48,12 +48,43 @@ x.backward() assert np.allclose((a.grad, b.grad), (3., 24.)) ``` +or simply wrap an existing tensorflow function + +```python +def tf_function(a, b): + c = 3 * a + 4 * b * b + + return c + +session = tf.compat.v1.Session() +f = tfpyth.wrap_torch_from_tensorflow( + tf_function, ["a", "b"], None, session=session + ) # automatically creates placeholders inside + +a_ = th.tensor(1, dtype=th.float32, requires_grad=True) +b_ = th.tensor(3, dtype=th.float32, requires_grad=True) +x = f(a_, b_) + +assert x == 39.0 + +x.backward() + +assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) +``` + +* see `tests` for more examples + + ## What it's got ### `torch_from_tensorflow` Creates a PyTorch function that is differentiable by evaluating a TensorFlow output tensor given input placeholders. +### `wrap_torch_from_tensorflow` + +Wrap a TensorFlow function into a PyTorch function and automatically create placeholders + ### `eager_tensorflow_from_torch` Creates an eager Tensorflow function from a PyTorch function. @@ -62,6 +93,8 @@ Creates an eager Tensorflow function from a PyTorch function. Creates a TensorFlow op/tensor from a PyTorch function. 
+ + ## Future work - [ ] support JAX diff --git a/tests/test_adapters.py b/tests/test_adapters.py index ab0f635..680ae33 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -55,3 +55,57 @@ def get_tf_function(): x.backward() assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) + + +class Test_wrap_torch_from_tensorflow: + def test_image_operation(self): + def tensorflow_function(a, size=(128, 128)): + return tf.image.resize(a, size=size) + + from functools import partial + + session = tf.compat.v1.Session() + tf_func = partial(tensorflow_function, size=(128, 128)) + f_pt = tfpyth.wrap_torch_from_tensorflow(tf_func, ["a"], [(None, 64, 64, 1)], session) + x = th.ones((1, 64, 64, 1), dtype=th.float32) + y = f_pt(x) + assert y.shape == (1, 128, 128, 1) + + def test_no_gradient_operation(self): + def tensorflow_function(a, size=(128, 128)): + return tf.image.resize(a, size=size) + + from functools import partial + + session = tf.compat.v1.Session() + tf_func = partial(tensorflow_function, size=(128, 128)) + f_pt = tfpyth.wrap_torch_from_tensorflow(tf_func, ["a"], [(None, 64, 64, 1)], session) + x = th.ones((1, 64, 64, 1), dtype=th.float32, requires_grad=False) + conv = th.nn.Conv2d(1, 1, 1) + x = conv(tfpyth.th_2D_channels_last_to_first(x)) + x = tfpyth.th_2D_channels_first_to_last(x) + y = f_pt(x) + + assert y.shape == (1, 128, 128, 1) + assert y.sum().backward() is None + assert conv.bias.grad + + def test_tensorflow_in_pytorch(self): + session = tf.compat.v1.Session() + + def get_tf_function(a, b): + c = 3 * a + 4 * b * b + + return c + + session = tf.compat.v1.Session() + f = tfpyth.wrap_torch_from_tensorflow(get_tf_function, ["a", "b"], None, session=session) + a_ = th.tensor(1, dtype=th.float32, requires_grad=True) + b_ = th.tensor(3, dtype=th.float32, requires_grad=True) + x = f(a_, b_) + + assert x == 39.0 + + x.backward() + + assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) diff --git a/tfpyth/__init__.py b/tfpyth/__init__.py index 
4fcee7f..0c0f59c 100644 --- a/tfpyth/__init__.py +++ b/tfpyth/__init__.py @@ -56,7 +56,7 @@ def forward(ctx, *args): # See https://www.janfreyberg.com/blog/2019-04-01-testing-pytorch-functions/ for why "no cover" @staticmethod - def backward(ctx, grad_output): # pragma: no cover + def backward(ctx, grad_output): # pragma: no cover th_inputs = ctx.saved_tensors feed_dict = {} @@ -69,6 +69,39 @@ def backward(ctx, grad_output): # pragma: no cover return _TensorFlowFunction() +def wrap_torch_from_tensorflow(func, tensor_inputs, input_shapes=None, session=None): + """wrap func using `torch_from_tensorflow` and automatically create placeholders. + + By default, placeholders are assumed to be `tf.float32`. + + :param func: Callable. + Tensorflow function to evaluate + :param tensor_input: List[str] + List of argument names to `func` that represent a tensor input. + :param input_shapes: List[Tuple[Int]]. + Shapes of input tensors if known. Some operations require these, such as all `tf.image.resize`. + Basically these values are fed to `tf.placeholder`, so you can indicate unknown parameters using `(None, 64, 64, 1)`, for instance. + :param session: tf.compat.v1.Session + A session. If None, will instantiate new session. 
+ + """ + if session: + session = tf.compat.v1.Session() + if input_shapes is not None: + if len(tensor_inputs) != len(input_shapes): + raise ValueError("Number of tensor inputs does not match number of input shapes") + else: + placeholders = { + arg_name: tf.compat.v1.placeholder(tf.float32, shape=shape, name=arg_name) + for arg_name, shape in zip(tensor_inputs, input_shapes) + } + else: + placeholders = {arg_name: tf.compat.v1.placeholder(tf.float32, name=arg_name) for arg_name in tensor_inputs} + output = func(**placeholders) + f = torch_from_tensorflow(session, [placeholders[t] for t in tensor_inputs], output).apply + return f + + def eager_tensorflow_from_torch(func): """ Wraps a PyTorch function into a TensorFlow eager-mode function (ie can be executed within Tensorflow eager-mode). @@ -106,3 +139,27 @@ def tensorflow_from_torch(func, inp, Tout, name=None): eager_compute = eager_tensorflow_from_torch(func) return tf.py_function(eager_compute, inp, Tout, name=name) + + +def tf_NCHW_to_NHWC(x): + return tf.transpose(x, (0, 2, 3, 1)) + + +def tf_NHWC_to_NCHW(x): + return tf.transpose(x, (0, 3, 1, 2)) + + +tf_2D_channels_first_to_last = tf_NCHW_to_NHWC +tf_2D_channels_last_to_first = tf_NHWC_to_NCHW + + +def th_NCHW_to_NHWC(x): + return x.permute((0, 2, 3, 1)) + + +def th_NHWC_to_NCHW(x): + return x.permute((0, 3, 1, 2)) + + +th_2D_channels_last_to_first = th_NHWC_to_NCHW +th_2D_channels_first_to_last = th_NCHW_to_NHWC From d76b52b8e44b54162bce69313bb70aae75a36109 Mon Sep 17 00:00:00 2001 From: Sandro Braun Date: Wed, 15 Apr 2020 16:45:57 +0200 Subject: [PATCH 2/8] adds automatic detection of tensorflow variable names from function declaration --- tests/test_adapters.py | 20 ++++++++++++++++++++ tfpyth/__init__.py | 7 +++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 680ae33..72fd0dd 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -109,3 +109,23 @@ def 
get_tf_function(a, b): x.backward() assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) + + def test_autodetect_varnames(self): + session = tf.compat.v1.Session() + + def get_tf_function(a, b): + c = 3 * a + 4 * b * b + + return c + + session = tf.compat.v1.Session() + f = tfpyth.wrap_torch_from_tensorflow(get_tf_function) + a_ = th.tensor(1, dtype=th.float32, requires_grad=True) + b_ = th.tensor(3, dtype=th.float32, requires_grad=True) + x = f(a_, b_) + + assert x == 39.0 + + x.backward() + + assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) diff --git a/tfpyth/__init__.py b/tfpyth/__init__.py index 0c0f59c..6007f54 100644 --- a/tfpyth/__init__.py +++ b/tfpyth/__init__.py @@ -69,7 +69,7 @@ def backward(ctx, grad_output): # pragma: no cover return _TensorFlowFunction() -def wrap_torch_from_tensorflow(func, tensor_inputs, input_shapes=None, session=None): +def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, session=None): """wrap func using `torch_from_tensorflow` and automatically create placeholders. By default, placeholders are assumed to be `tf.float32`. @@ -85,8 +85,11 @@ def wrap_torch_from_tensorflow(func, tensor_inputs, input_shapes=None, session=N A session. If None, will instantiate new session. 
""" - if session: + if session is None: session = tf.compat.v1.Session() + if tensor_inputs is None: + tensor_inputs = func.__code__.co_varnames[: func.__code__.co_argcount] + if input_shapes is not None: if len(tensor_inputs) != len(input_shapes): raise ValueError("Number of tensor inputs does not match number of input shapes") From 6cb9d13faad3b6d2f5fc980f23d2e3d1883f0e96 Mon Sep 17 00:00:00 2001 From: Sandro Braun Date: Wed, 15 Apr 2020 16:50:39 +0200 Subject: [PATCH 3/8] docstring and readme --- README.md | 12 ++++++++++-- tfpyth/__init__.py | 1 + 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d2685ae..9a650bb 100644 --- a/README.md +++ b/README.md @@ -58,8 +58,16 @@ def tf_function(a, b): session = tf.compat.v1.Session() f = tfpyth.wrap_torch_from_tensorflow( - tf_function, ["a", "b"], None, session=session - ) # automatically creates placeholders inside + tf_function, ["a", "b"], session=session + ) +# or simpler +f = tfpyth.wrap_torch_from_tensorflow( + tf_function, session=session + ) # automatically creates placeholders for "a" and "b" inside +# or even simpler +f = tfpyth.wrap_torch_from_tensorflow( + tf_function + ) # automatically creates placeholders for "a" and "b" and session a_ = th.tensor(1, dtype=th.float32, requires_grad=True) b_ = th.tensor(3, dtype=th.float32, requires_grad=True) diff --git a/tfpyth/__init__.py b/tfpyth/__init__.py index 6007f54..8e4ff52 100644 --- a/tfpyth/__init__.py +++ b/tfpyth/__init__.py @@ -78,6 +78,7 @@ def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, sess Tensorflow function to evaluate :param tensor_input: List[str] List of argument names to `func` that represent a tensor input. + if not provided will interpret all arguments from func as tensorflow placeholders. :param input_shapes: List[Tuple[Int]]. Shapes of input tensors if known. Some operations require these, such as all `tf.image.resize`. 
Basically these values are fed to `tf.placeholder`, so you can indicate unknown parameters using `(None, 64, 64, 1)`, for instance. From 4ec56a506f783e3cb2382c892b3ca25addcd6381 Mon Sep 17 00:00:00 2001 From: Sandro Braun Date: Thu, 16 Apr 2020 11:27:53 +0200 Subject: [PATCH 4/8] adds support for multiple outputs in torch_from_tensorflow --- tests/test_adapters.py | 75 ++++++++++++++++++++++++++++++++++++++++++ tfpyth/__init__.py | 71 +++++++++++++++++++++++---------------- 2 files changed, 118 insertions(+), 28 deletions(-) diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 72fd0dd..47833a0 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -57,6 +57,60 @@ def get_tf_function(): assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) +class Test_tensorflow_in_pytorch: + def test_single_output(self): + session = tf.Session() + + def get_tf_function(): + a = tf.placeholder(tf.float32, name="a") + b = tf.placeholder(tf.float32, name="b") + c = 3 * a + 4 * b * b + + f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply + return f + + f = get_tf_function() + a_ = th.tensor(1, dtype=th.float32, requires_grad=True) + b_ = th.tensor(3, dtype=th.float32, requires_grad=True) + x = f(a_, b_) + + assert x == 39.0 + + x.backward() + + assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) + + def test_multiple_outputs(self): + session = tf.Session() + + def get_tf_function(): + a = tf.placeholder(tf.float32, name="a") + b = tf.placeholder(tf.float32, name="b") + c = 3 * a + 4 * b * b + d = 6 * a + 8 * b ** 2 + + f = tfpyth.torch_from_tensorflow(session, [a, b], [c, d]) + f1, f2 = [ff.apply for ff in f] + return f1, f2 + + f1, f2 = get_tf_function() + + def f(a, b): + return f1(a, b), f2(a, b) + + a_ = th.tensor(1, dtype=th.float32, requires_grad=True) + b_ = th.tensor(3, dtype=th.float32, requires_grad=True) + x1, x2 = f(a_, b_) + + assert x1 == 39.0 + assert x2 == 78.0 + + x1.backward() + x2.backward() + + assert np.allclose((a_.grad, b_.grad), 
(9.0, 72.0)) + + class Test_wrap_torch_from_tensorflow: def test_image_operation(self): def tensorflow_function(a, size=(128, 128)): @@ -110,6 +164,27 @@ def get_tf_function(a, b): assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) + def test_multiple_outputs(self): + session = tf.compat.v1.Session() + + def get_tf_function(a, b): + c = 3 * a + 4 * b * b + d = 6 * a + 8 * b ** 2 + + return c, d + + session = tf.compat.v1.Session() + f = tfpyth.wrap_torch_from_tensorflow(get_tf_function, ["a", "b"], None, session=session) + a_ = th.tensor(1, dtype=th.float32, requires_grad=True) + b_ = th.tensor(3, dtype=th.float32, requires_grad=True) + x = f(a_, b_) + + assert x == 39.0 + + x.backward() + + assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) + def test_autodetect_varnames(self): session = tf.compat.v1.Session() diff --git a/tfpyth/__init__.py b/tfpyth/__init__.py index 8e4ff52..fafb418 100644 --- a/tfpyth/__init__.py +++ b/tfpyth/__init__.py @@ -1,5 +1,6 @@ import tensorflow as tf import torch as th +import functools class TensorFlowFunction(th.autograd.Function): @@ -13,7 +14,7 @@ class TensorFlowFunction(th.autograd.Function): gradient_outputs = None -def torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype=tf.float32): +def torch_from_tensorflow(tf_session, tf_inputs, tf_outputs, tf_dtype=tf.float32): """ Create a PyTorch TensorFlowFunction with forward and backward methods which executes evaluates the passed TensorFlow tensors. @@ -31,42 +32,54 @@ def torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype=tf.float32) :return: TensorflowFunction which can be applied to PyTorch tensors. 
""" # create gradient placeholders - tf_gradient_placeholder = tf.placeholder(dtype=tf_dtype, name=f"gradient") - tf_gradient_outputs = tf.gradients( - ys=tf_output, xs=tf_inputs, grad_ys=[tf_gradient_placeholder], unconnected_gradients="zero" - ) - class _TensorFlowFunction(TensorFlowFunction): - inputs = tf_inputs - output = tf_output - gradient_placeholder = tf_gradient_placeholder - gradient_outputs = tf_gradient_outputs + def _torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype=tf.float32): + tf_gradient_placeholder = tf.placeholder(dtype=tf_dtype, name=f"gradient") + tf_gradient_outputs = tf.gradients( + ys=tf_output, xs=tf_inputs, grad_ys=[tf_gradient_placeholder], unconnected_gradients="zero" + ) - @staticmethod - def forward(ctx, *args): - assert len(args) == len(tf_inputs) + class _TensorFlowFunction(TensorFlowFunction): + inputs = tf_inputs + output = tf_output + gradient_placeholder = tf_gradient_placeholder + gradient_outputs = tf_gradient_outputs - feed_dict = {tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, args)} - output = tf_session.run(tf_output, feed_dict) + @staticmethod + def forward(ctx, *args): + assert len(args) == len(tf_inputs) - ctx.save_for_backward(*args) + feed_dict = {tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, args)} + output = tf_session.run(tf_output, feed_dict) - th_output = th.as_tensor(output) - return th_output + ctx.save_for_backward(*args) - # See https://www.janfreyberg.com/blog/2019-04-01-testing-pytorch-functions/ for why "no cover" - @staticmethod - def backward(ctx, grad_output): # pragma: no cover - th_inputs = ctx.saved_tensors + th_output = th.as_tensor(output) + return th_output - feed_dict = {} - feed_dict.update({tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, th_inputs)}) - feed_dict.update({tf_gradient_placeholder: grad_output.detach().numpy()}) + # See 
https://www.janfreyberg.com/blog/2019-04-01-testing-pytorch-functions/ for why "no cover" + @staticmethod + def backward(ctx, grad_output): # pragma: no cover + th_inputs = ctx.saved_tensors - tf_gradients = tf_session.run(tf_gradient_outputs, feed_dict) - return tuple(th.as_tensor(tf_gradient) for tf_gradient in tf_gradients) + feed_dict = {} + feed_dict.update( + {tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, th_inputs)} + ) + feed_dict.update({tf_gradient_placeholder: grad_output.detach().numpy()}) - return _TensorFlowFunction() + tf_gradients = tf_session.run(tf_gradient_outputs, feed_dict) + return tuple(th.as_tensor(tf_gradient) for tf_gradient in tf_gradients) + + return _TensorFlowFunction() + + if isinstance(tf_outputs, list): + output_functions = [] + for tf_output in tf_outputs: + output_functions.append(_torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype)) + return output_functions + else: + return _torch_from_tensorflow(tf_session, tf_inputs, tf_outputs, tf_dtype) def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, session=None): @@ -89,6 +102,8 @@ def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, sess if session is None: session = tf.compat.v1.Session() if tensor_inputs is None: + if isinstance(func, functools.partial): + func = func.func tensor_inputs = func.__code__.co_varnames[: func.__code__.co_argcount] if input_shapes is not None: From dfc35b31f34ab0a6ab6c326e9bfb0a8e613b9ebb Mon Sep 17 00:00:00 2001 From: Sandro Braun Date: Thu, 16 Apr 2020 11:38:28 +0200 Subject: [PATCH 5/8] adjusts wrap_torch_from_tensorflow for multiple outputs --- tests/test_adapters.py | 10 ++++++---- tfpyth/__init__.py | 16 +++++++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 47833a0..2d6e41d 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -177,13 +177,15 @@ def 
get_tf_function(a, b): f = tfpyth.wrap_torch_from_tensorflow(get_tf_function, ["a", "b"], None, session=session) a_ = th.tensor(1, dtype=th.float32, requires_grad=True) b_ = th.tensor(3, dtype=th.float32, requires_grad=True) - x = f(a_, b_) - - assert x == 39.0 + x1, x2 = f(a_, b_) - x.backward() + assert x1 == 39.0 + assert x2 == 78.0 + x1.backward() assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) + x2.backward() # partial derivatives are additive + assert np.allclose((a_.grad, b_.grad), (9.0, 72.0)) def test_autodetect_varnames(self): session = tf.compat.v1.Session() diff --git a/tfpyth/__init__.py b/tfpyth/__init__.py index fafb418..575cf98 100644 --- a/tfpyth/__init__.py +++ b/tfpyth/__init__.py @@ -97,7 +97,6 @@ def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, sess Basically these values are fed to `tf.placeholder`, so you can indicate unknown parameters using `(None, 64, 64, 1)`, for instance. :param session: tf.compat.v1.Session A session. If None, will instantiate new session. 
- """ if session is None: session = tf.compat.v1.Session() @@ -116,8 +115,19 @@ def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, sess } else: placeholders = {arg_name: tf.compat.v1.placeholder(tf.float32, name=arg_name) for arg_name in tensor_inputs} - output = func(**placeholders) - f = torch_from_tensorflow(session, [placeholders[t] for t in tensor_inputs], output).apply + outputs = func(**placeholders) + + if isinstance(outputs, tuple): + fs = [ + torch_from_tensorflow(session, [placeholders[t] for t in tensor_inputs], output).apply for output in outputs + ] + + def f(*args): + return [ff(*args) for ff in fs] + + else: + output = outputs + f = torch_from_tensorflow(session, [placeholders[t] for t in tensor_inputs], output).apply return f From 2002b81c83b346e6c3bcedcd9d790dde6685189e Mon Sep 17 00:00:00 2001 From: Sandro Braun Date: Thu, 16 Apr 2020 17:03:35 +0200 Subject: [PATCH 6/8] adds custom datatypes --- tests/test_adapters.py | 6 +++--- tfpyth/__init__.py | 21 ++++++++++++++++----- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 2d6e41d..df75a91 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -120,7 +120,7 @@ def tensorflow_function(a, size=(128, 128)): session = tf.compat.v1.Session() tf_func = partial(tensorflow_function, size=(128, 128)) - f_pt = tfpyth.wrap_torch_from_tensorflow(tf_func, ["a"], [(None, 64, 64, 1)], session) + f_pt = tfpyth.wrap_torch_from_tensorflow(tf_func, ["a"], [(None, 64, 64, 1)], session=session) x = th.ones((1, 64, 64, 1), dtype=th.float32) y = f_pt(x) assert y.shape == (1, 128, 128, 1) @@ -133,7 +133,7 @@ def tensorflow_function(a, size=(128, 128)): session = tf.compat.v1.Session() tf_func = partial(tensorflow_function, size=(128, 128)) - f_pt = tfpyth.wrap_torch_from_tensorflow(tf_func, ["a"], [(None, 64, 64, 1)], session) + f_pt = tfpyth.wrap_torch_from_tensorflow(tf_func, ["a"], [(None, 64, 64, 1)], 
session=session) x = th.ones((1, 64, 64, 1), dtype=th.float32, requires_grad=False) conv = th.nn.Conv2d(1, 1, 1) x = conv(tfpyth.th_2D_channels_last_to_first(x)) @@ -184,7 +184,7 @@ def get_tf_function(a, b): x1.backward() assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) - x2.backward() # partial derivatives are additive + x2.backward() # partial derivatives are additive assert np.allclose((a_.grad, b_.grad), (9.0, 72.0)) def test_autodetect_varnames(self): diff --git a/tfpyth/__init__.py b/tfpyth/__init__.py index 575cf98..62c8eca 100644 --- a/tfpyth/__init__.py +++ b/tfpyth/__init__.py @@ -82,7 +82,7 @@ def backward(ctx, grad_output): # pragma: no cover return _torch_from_tensorflow(tf_session, tf_inputs, tf_outputs, tf_dtype) -def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, session=None): +def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, input_dtypes=None, session=None): """wrap func using `torch_from_tensorflow` and automatically create placeholders. By default, placeholders are assumed to be `tf.float32`. @@ -95,6 +95,8 @@ def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, sess :param input_shapes: List[Tuple[Int]]. Shapes of input tensors if known. Some operations require these, such as all `tf.image.resize`. Basically these values are fed to `tf.placeholder`, so you can indicate unknown parameters using `(None, 64, 64, 1)`, for instance. + :param input_dtypes: List[tf.dtype]. + Data types to associate inputs with. By default, will treat all inputs as `tf.float32` :param session: tf.compat.v1.Session A session. If None, will instantiate new session. 
""" @@ -109,10 +111,19 @@ def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, sess if len(tensor_inputs) != len(input_shapes): raise ValueError("Number of tensor inputs does not match number of input shapes") else: - placeholders = { - arg_name: tf.compat.v1.placeholder(tf.float32, shape=shape, name=arg_name) - for arg_name, shape in zip(tensor_inputs, input_shapes) - } + if input_dtypes is not None: + if len(input_dtypes) != len(input_shapes): + raise ValueError("Number of tensor input dtypes does not match number of input shapes") + else: + placeholders = { + arg_name: tf.compat.v1.placeholder(shape=shape, dtype=dtype, name=arg_name) + for arg_name, shape, dtype in zip(tensor_inputs, input_shapes, input_dtypes) + } + else: + placeholders = { + arg_name: tf.compat.v1.placeholder(tf.float32, shape=shape, name=arg_name) + for arg_name, shape in zip(tensor_inputs, input_shapes) + } else: placeholders = {arg_name: tf.compat.v1.placeholder(tf.float32, name=arg_name) for arg_name in tensor_inputs} outputs = func(**placeholders) From 0e64e79b1b2d52ab2d70a875c93c646271cd1ffa Mon Sep 17 00:00:00 2001 From: Sandro Braun Date: Tue, 28 Apr 2020 13:41:55 +0200 Subject: [PATCH 7/8] adds session singleton --- tfpyth/__init__.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tfpyth/__init__.py b/tfpyth/__init__.py index 62c8eca..720c21f 100644 --- a/tfpyth/__init__.py +++ b/tfpyth/__init__.py @@ -3,6 +3,17 @@ import functools +class SingleSession: + instance = None + """https://python-3-patterns-idioms-test.readthedocs.io/en/latest/Singleton.html""" + + def __init__(self): + if not SingleSession.instance: + SingleSession.instance = tf.compat.v1.Session() + + def get_session(self): + return SingleSession.instance + class TensorFlowFunction(th.autograd.Function): """ Wrapper class for Tensorflow input/output nodes (incl gradient) in PyTorch. 
@@ -49,7 +60,15 @@ class _TensorFlowFunction(TensorFlowFunction): def forward(ctx, *args): assert len(args) == len(tf_inputs) - feed_dict = {tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, args)} + feed_dict = {} + for tf_input, th_input in zip(tf_inputs, args): + if th_input.is_cuda: + feed_dict[tf_input] = th_input.cpu().detach().numpy() + else: + feed_dict[tf_input] = th_input.detach().numpy() + + # TODO: write test for cuda tensors + # feed_dict = {tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, args)} output = tf_session.run(tf_output, feed_dict) ctx.save_for_backward(*args) @@ -101,7 +120,7 @@ def wrap_torch_from_tensorflow(func, tensor_inputs=None, input_shapes=None, inpu A session. If None, will instantiate new session. """ if session is None: - session = tf.compat.v1.Session() + session = SingleSession().get_session() if tensor_inputs is None: if isinstance(func, functools.partial): func = func.func From ec0f60fe92a5d43ea401a6f59447497a2c4bc19a Mon Sep 17 00:00:00 2001 From: Sandro Braun Date: Tue, 28 Apr 2020 16:42:22 +0200 Subject: [PATCH 8/8] adds notes on session --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index 9a650bb..8fc39d6 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,20 @@ Creates an eager Tensorflow function from a PyTorch function. Creates a TensorFlow op/tensor from a PyTorch function. +## Notes on session management + + +* when using `wrap_torch_from_tensorflow` without a `session` argument, a (singleton) session will be created in the background and used for every call to `wrap_torch_from_tensorflow`. +* one can access this session using + +```python +import tfpyth + +session = tfpyth.SingleSession().get_session() + +``` + + ## Future work