Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit 700d236

Browse files
committed
Squeeze tests
PiperOrigin-RevId: 327857859
1 parent de2a61d commit 700d236

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

tests/monte_carlo_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@
3838

3939
STORE_ON_DEVICE = [True, False]
4040

41-
N_SAMPLES = 4
42-
4341
ALL_GET = ('nngp', 'ntk', ('nngp', 'ntk'), None)
4442

4543
test_utils.update_test_tolerance()

tests/stax_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,6 @@ def test_activation(self, same_inputs, model, phi_name, get, abc):
940940
for get in ['nngp', 'ntk']
941941
for gamma in [1e-6, 1e-4, 1e-2, 1.0, 2.]
942942
))
943-
944943
def test_rbf(self, same_inputs, model, get, gamma):
945944
activation = stax.Rbf(gamma)
946945
self._test_activation(activation, same_inputs, model, get,
@@ -2138,7 +2137,7 @@ def get_attn():
21382137
test_utils.assert_close_matrices(self, empirical, exact, tol)
21392138

21402139

2141-
class GNTKTest(test_utils.NeuralTangentsTestCase):
2140+
class AggregateTest(test_utils.NeuralTangentsTestCase):
21422141
@jtu.parameterized.named_parameters(
21432142
jtu.cases_from_list({
21442143
'testcase_name':
@@ -2157,9 +2156,9 @@ class GNTKTest(test_utils.NeuralTangentsTestCase):
21572156
for test_mask in [True]
21582157
))
21592158

2160-
def test_GNTK(self, get, readout, same_input, activation, test_mask):
2159+
def test_aggregate(self, get, readout, same_input, activation, test_mask):
21612160
batch1, batch2 = 8, 6
2162-
num_nodes, num_channels = 8, 12
2161+
num_nodes, num_channels = 4, 2
21632162
output_dims = 1 if get == 'ntk' else 1024
21642163
key = random.PRNGKey(1)
21652164
key, split1, split2 = random.split(key, 3)
@@ -2183,15 +2182,16 @@ def test_GNTK(self, get, readout, same_input, activation, test_mask):
21832182

21842183
# Build the infinite network.
21852184
init_fn, apply_fn, kernel_fn = stax.serial(
2186-
stax.Dense(128*8*4),
2185+
stax.Dense(128*8),
21872186
activation,
21882187
stax.Dropout(0.5, mode='train'),
21892188
stax.Aggregate(),
21902189
readout,
21912190
stax.Dense(output_dims))
21922191
kernel_fn = batch.batch(kernel_fn, batch_size=2)
21932192
kernel_mc_fn = monte_carlo.monte_carlo_kernel_fn(
2194-
init_fn, apply_fn, random.PRNGKey(10), 300)
2193+
init_fn, apply_fn, random.PRNGKey(10), 128,
2194+
batch_size=2 if xla_bridge.get_backend().platform == 'tpu' else 0)
21952195
empirical = kernel_mc_fn(x1, x2, get,
21962196
mask_constant=mask_constant if test_mask else None,
21972197
pattern=(pattern1, pattern2))

0 commit comments

Comments
 (0)