Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit a240b24

Browse files
authored
Merge pull request #59 from DarrenZhang01/master
Add the TensorFlow version of some JAX utilities
2 parents 700d236 + ab17374 commit a240b24

File tree

4 files changed

+904
-0
lines changed

4 files changed

+904
-0
lines changed

tests/lax_test.py

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
17+
import tensorflow as tf
18+
from tf_helpers import lax
19+
from tensorflow.python.platform import test
20+
from absl.testing import parameterized
21+
import itertools
22+
import numpy as onp
23+
from tensorflow.python.ops import numpy_ops as tfnp
24+
from jax import numpy as jnp
25+
import jax
26+
import sys
27+
28+
29+
class TFLaxTest(tf.test.TestCase, parameterized.TestCase):
30+
31+
@parameterized.parameters(
32+
{"lhs_np": onp.ones((5, 3)), "rhs_np": onp.ones((3, 2)),
33+
"dims": (((1,), (0,)), ((), ()))},
34+
{"lhs_np": onp.ones((5, 3)), "rhs_np": onp.ones((5, 3)),
35+
"dims": (((0, 1), (0, 1)), ((), ()))},
36+
{"lhs_np": onp.ones((5, 3, 2)), "rhs_np": onp.ones((2, 3, 2)),
37+
"dims": (((1, 2), (1, 0)), ((), ()))},
38+
{"lhs_np": onp.ones((6, 5, 3)), "rhs_np": onp.ones((6, 3, 2)),
39+
"dims": (((2,), (1,)), ((0,), (0,)))},
40+
{"lhs_np": onp.ones((6, 3, 5)), "rhs_np": onp.ones((6, 3, 2)),
41+
"dims": (((1,), (1,)), ((0,), (0,)))},
42+
{"lhs_np": onp.ones((5, 3, 2, 2)), "rhs_np": onp.ones((5, 2, 2, 6)),
43+
"dims": (((2, 3), (1, 2)), ((0,), (0,)))},
44+
{"lhs_np": onp.ones((2, 2, 5, 3)), "rhs_np": onp.ones((2, 2, 3, 2)),
45+
"dims": (((3,), (2,)), ((0, 1), (0, 1)))},
46+
{"lhs_np": onp.ones((2, 2, 5, 2)), "rhs_np": onp.ones((2, 2, 3, 2)),
47+
"dims": (((3,), (1,)), ((0,), (0,)))},
48+
{"lhs_np": onp.ones((2, 2, 5, 3, 3)), "rhs_np": onp.ones((2, 3, 2, 3, 2)),
49+
"dims": (((4,), (1,)), ((0,), (0,)))},
50+
)
51+
def test_tf_dot_general(self, lhs_np, rhs_np, dims):
52+
ans = jax.lax.dot_general(lhs_np, rhs_np, dims)
53+
result = lax.dot_general(lhs_np, rhs_np, dims)
54+
self.assertAllClose(result, tfnp.array(ans))
55+
56+
@parameterized.named_parameters([
57+
("_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
58+
"_lhs_dilation={}_rhs_dilation={}"
59+
"_feature_group_count={}_batch_group_count={}_dims={}"
60+
"_perms={}".format(lhs_shape, rhs_shape,
61+
strides, padding, lhs_dilation, rhs_dilation,
62+
feature_group_count, batch_group_count, ",".join(dimension_numbers), perms),
63+
lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation,
64+
feature_group_count, batch_group_count, dimension_numbers, perms)
65+
for batch_group_count, feature_group_count in [(1, 1)]
66+
for lhs_shape, rhs_shape in [
67+
((b * batch_group_count, i * feature_group_count, 9, w),
68+
(j * feature_group_count * batch_group_count, i, 4, 5))
69+
for w in [0, 10]
70+
for b, i, j in itertools.product([2, 3], repeat=3)]
71+
for strides in [(1, 1), (2, 1)]
72+
for padding in ['SAME']
73+
for lhs_dilation, rhs_dilation in [
74+
(None, (1, 1))
75+
]
76+
for dimension_numbers, perms in [
77+
(("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))
78+
]])
79+
def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides,
80+
padding, lhs_dilation, rhs_dilation,
81+
feature_group_count, batch_group_count,
82+
dimension_numbers, perms):
83+
tf.print("dimension_numbers: {}".format(dimension_numbers), output_stream=sys.stdout)
84+
lhs_perm, rhs_perm = perms # permute to compatible shapes
85+
86+
lhs_tf = tfnp.transpose(tfnp.ones(lhs_shape), lhs_perm)
87+
rhs_tf = tfnp.transpose(tfnp.ones(rhs_shape), rhs_perm)
88+
89+
lhs_jax = jnp.transpose(jnp.ones(lhs_shape), lhs_perm)
90+
rhs_jax = jnp.transpose(jnp.ones(rhs_shape), rhs_perm)
91+
92+
jax_conv = jax.lax.conv_general_dilated(lhs_jax, rhs_jax, strides, padding, lhs_dilation,
93+
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count)
94+
95+
tf_conv = lax.conv_general_dilated(lhs_tf, rhs_tf, strides, padding, jax_conv.shape, lhs_dilation,
96+
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count)
97+
98+
self.assertAllEqual(tf_conv, tfnp.asarray(jax_conv))
99+
100+
101+
if __name__ == "__main__":
102+
test.main()

tf_helpers/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)