Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Add the TensorFlow version of some JAX utilities #59

Merged
merged 28 commits into from
Aug 26, 2020

Conversation

DarrenZhang01
Copy link
Contributor

@DarrenZhang01 DarrenZhang01 commented Aug 24, 2020

As titled, this Pull Request contains some necessary TF-based helper APIs for Neural Tangents, and they will be served as the main support.

Copy link
Contributor

@romanngg romanngg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look great! Left a few minor comments.

Construct an equivalent general dot operation as that in JAX -
<https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot_general.html>

Although there is an implementation in TF XLA, avoid directly using XLA when
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, great question! I think Ashish mentioned that we want the newly constructed APIs - general conv, etc to be able to inter-op with the rest of the TF ecosystem, but I guess TF XLA is too low-level and independent. @wangpengmit Would you mind verifying this?

@DarrenZhang01
Copy link
Contributor Author

Tests coming soon.

@DarrenZhang01 DarrenZhang01 changed the title Add the TensorFlow version of general dot operation to tf_hlpers Add the TensorFlow version of some JAX lax utilities Aug 25, 2020
@DarrenZhang01
Copy link
Contributor Author

I will also add ostax into this pull request shortly.

@DarrenZhang01
Copy link
Contributor Author

Probably not going to add test cases for ostax since there are some dependencies that are still not there.

@DarrenZhang01 DarrenZhang01 requested a review from romanngg August 26, 2020 00:49
@DarrenZhang01 DarrenZhang01 changed the title Add the TensorFlow version of some JAX lax utilities Add the TensorFlow version of some JAX utilities Aug 26, 2020
DarrenZhang01 added 2 commits August 26, 2020 16:16
the TF pool shape checker in order to make the TF `reduce_window` API
consistent with JAX `reduce_window`.
@DarrenZhang01
Copy link
Contributor Author

DarrenZhang01 commented Aug 26, 2020

@DarrenZhang01 DarrenZhang01 requested a review from romanngg August 26, 2020 20:57
@romanngg romanngg merged commit a240b24 into google:neural-tangents-tf Aug 26, 2020
DarrenZhang01 pushed a commit to DarrenZhang01/neural-tangents that referenced this pull request Sep 2, 2020
DarrenZhang01 pushed a commit to DarrenZhang01/neural-tangents that referenced this pull request Sep 2, 2020
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants