
Simpler interface that automatically creates placeholders #3

Open
theRealSuperMario opened this issue Apr 15, 2020 · 2 comments


@theRealSuperMario

theRealSuperMario commented Apr 15, 2020

Hi. First of all, thanks for writing this library. It saves me a lot of time: I can just wrap TensorFlow code in a new function and use it as-is in my PyTorch projects.

That is exactly the gist of this issue: I thought we could make the interface a little simpler, so that you literally just have to wrap your code in a new function.

I think it should look something like this:

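import numpy as np
import tensorflow as tf
import torch as th
import tfpyth

# Placeholders and sessions need graph mode under TF 2.x.
tf.compat.v1.disable_eager_execution()

# A plain TensorFlow function; no placeholders are declared by hand.
def tf_function(a, b):
    c = 3 * a + 4 * b * b
    return c

# Proposed interface: the three calls below are alternatives.
session = tf.compat.v1.Session()
# 1) Explicit input names and session.
f = tfpyth.wrap_torch_from_tensorflow(tf_function, ["a", "b"], session=session)
# 2) Or simpler: placeholders for "a" and "b" are created automatically.
f = tfpyth.wrap_torch_from_tensorflow(tf_function, session=session)
# 3) Or even simpler: placeholders for "a" and "b" and the session are created automatically.
f = tfpyth.wrap_torch_from_tensorflow(tf_function)

a_ = th.tensor(1, dtype=th.float32, requires_grad=True)
b_ = th.tensor(3, dtype=th.float32, requires_grad=True)
x = f(a_, b_)

# 3 * 1 + 4 * 3 * 3 = 39
assert x == 39.0

x.backward()

# dc/da = 3 and dc/db = 8 * b = 24
assert np.allclose((a_.grad.item(), b_.grad.item()), (3.0, 24.0))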
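One way the "automatically creates placeholders" variants could discover the input names is to read them off the wrapped function's signature. This is a minimal sketch, assuming only positional parameters become placeholders; `infer_input_names` and `build_graph` are hypothetical helpers, not part of tfpyth, and `make_placeholder` stands in for whatever actually creates each placeholder.

```python
import inspect


def infer_input_names(func):
    """Return the names of func's positional parameters, in declaration order."""
    positional = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return [
        p.name
        for p in inspect.signature(func).parameters.values()
        if p.kind in positional
    ]


def build_graph(func, make_placeholder):
    """Create one placeholder per inferred input name and call func on them.

    Hypothetical helper: make_placeholder(name) is assumed to return the
    placeholder object for that name.
    """
    names = infer_input_names(func)
    placeholders = [make_placeholder(name) for name in names]
    output = func(*placeholders)
    return placeholders, output
```

With TensorFlow in graph mode, `make_placeholder` could be `lambda name: tf.compat.v1.placeholder(tf.float32, name=name)`; the returned placeholders and output tensor are what a wrapper would then hand to tfpyth's graph-to-PyTorch conversion. Keeping the factory as a parameter also lets the name inference be exercised without TensorFlow installed.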

Turns out, I have already gone ahead and implemented this exact feature; you can check it out here. I was just wondering whether a PR pursuing this feature would be of interest.

Cheers

@BlackHC
Owner

BlackHC commented Jun 8, 2020

Oh yes, this is cool! 🎉 I'm currently scrambling to meet some deadlines, but I'll get back to you next week.

It looks very useful. I need to think about the session singleton. I'm wary of having too much state, so it might be easier to just use tf.compat.v1.get_default_session()? Everything else could go in right away I think.
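A minimal sketch of the stateless alternative: use an explicit session if one is passed, otherwise fall back to `tf.compat.v1.get_default_session()`, and fail loudly instead of creating and caching a global session. `resolve_session` is a hypothetical helper, and its `get_default` parameter is an assumption added only so the fallback can be swapped out in tests.

```python
def resolve_session(session=None, get_default=None):
    """Return the session to run the wrapped graph in, without global state.

    Hypothetical helper: prefers an explicit session, then the default
    session; raises if neither is available.
    """
    if session is not None:
        return session
    if get_default is None:
        # Imported lazily so the helper can be tested without TensorFlow.
        import tensorflow as tf

        get_default = tf.compat.v1.get_default_session
    default = get_default()  # None when no default session is active
    if default is None:
        raise RuntimeError(
            "No session was passed and no default session is active; "
            "pass session=... or run inside a default session."
        )
    return default
```

Under this design the caller owns the session's lifetime, for example by calling the wrapped function inside `with tf.compat.v1.Session() as sess:`, which installs `sess` as the default session for the block.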

Thanks!

@theRealSuperMario
Author

theRealSuperMario commented Jun 9, 2020 via email
