-
Notifications
You must be signed in to change notification settings - Fork 231
Add the TensorFlow version of some JAX utilities #59
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Look great! Left a few minor comments.
tf_helpers/tf_dot_general.py
Outdated
Construct an equivalent general dot operation as that in JAX - | ||
<https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html> | ||
|
||
Although there is an implementation in TF XLA, avoid directly using XLA when |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious, why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah, great question! I think Ashish mentioned that we want the newly constructed APIs - general conv, etc to be able to inter-op with the rest of the TF ecosystem, but I guess TF XLA is too low-level and independent. @wangpengmit Would you mind verifying this?
Tests coming soon. |
tf_hlpers
TensorFlow-related ecosystem and the tests folder contain both the `lax` and its test cases.
I will also add |
Probably not going to add test cases for |
the TF pool shape checker in order to make the TF `reduce_window` API consistent with JAX `reduce_window`.
stax back to TF stax.
Mark UPDATE (2nd September, 2020): Finished in #63. |
As titled, this Pull Request contains some necessary TF-based helper APIs for Neural Tangents, and they will be served as the main support.