@@ -940,7 +940,6 @@ def test_activation(self, same_inputs, model, phi_name, get, abc):
940
940
for get in ['nngp' , 'ntk' ]
941
941
for gamma in [1e-6 , 1e-4 , 1e-2 , 1.0 , 2. ]
942
942
))
943
-
944
943
def test_rbf (self , same_inputs , model , get , gamma ):
945
944
activation = stax .Rbf (gamma )
946
945
self ._test_activation (activation , same_inputs , model , get ,
@@ -2138,7 +2137,7 @@ def get_attn():
2138
2137
test_utils .assert_close_matrices (self , empirical , exact , tol )
2139
2138
2140
2139
2141
- class GNTKTest (test_utils .NeuralTangentsTestCase ):
2140
+ class AggregateTest (test_utils .NeuralTangentsTestCase ):
2142
2141
@jtu .parameterized .named_parameters (
2143
2142
jtu .cases_from_list ({
2144
2143
'testcase_name' :
@@ -2157,9 +2156,9 @@ class GNTKTest(test_utils.NeuralTangentsTestCase):
2157
2156
for test_mask in [True ]
2158
2157
))
2159
2158
2160
- def test_GNTK (self , get , readout , same_input , activation , test_mask ):
2159
+ def test_aggregate (self , get , readout , same_input , activation , test_mask ):
2161
2160
batch1 , batch2 = 8 , 6
2162
- num_nodes , num_channels = 8 , 12
2161
+ num_nodes , num_channels = 4 , 2
2163
2162
output_dims = 1 if get == 'ntk' else 1024
2164
2163
key = random .PRNGKey (1 )
2165
2164
key , split1 , split2 = random .split (key , 3 )
@@ -2183,15 +2182,16 @@ def test_GNTK(self, get, readout, same_input, activation, test_mask):
2183
2182
2184
2183
# Build the infinite network.
2185
2184
init_fn , apply_fn , kernel_fn = stax .serial (
2186
- stax .Dense (128 * 8 * 4 ),
2185
+ stax .Dense (128 * 8 ),
2187
2186
activation ,
2188
2187
stax .Dropout (0.5 , mode = 'train' ),
2189
2188
stax .Aggregate (),
2190
2189
readout ,
2191
2190
stax .Dense (output_dims ))
2192
2191
kernel_fn = batch .batch (kernel_fn , batch_size = 2 )
2193
2192
kernel_mc_fn = monte_carlo .monte_carlo_kernel_fn (
2194
- init_fn , apply_fn , random .PRNGKey (10 ), 300 )
2193
+ init_fn , apply_fn , random .PRNGKey (10 ), 128 ,
2194
+ batch_size = 2 if xla_bridge .get_backend ().platform == 'tpu' else 0 )
2195
2195
empirical = kernel_mc_fn (x1 , x2 , get ,
2196
2196
mask_constant = mask_constant if test_mask else None ,
2197
2197
pattern = (pattern1 , pattern2 ))
0 commit comments