This repository was archived by the owner on May 6, 2025. It is now read-only.

Make the test of the infinite_fcn example file pass. #63

Open
wants to merge 10 commits into base: neural-tangents-tf
1 change: 1 addition & 0 deletions .travis.yml
@@ -17,3 +17,4 @@ script:
 # errors in stax tests and infinite fcn tests.
 - python tests/function_space_test.py
 - python tests/weight_space_test.py
+- python tests/infinite_fcn_test.py
7 changes: 3 additions & 4 deletions examples/infinite_fcn.py
@@ -20,12 +20,14 @@
 import time
 from absl import app
 from absl import flags
-import jax.numpy as np
 import neural_tangents as nt
 from neural_tangents import stax
 from examples import datasets
 from examples import util
 
+import tensorflow as tf
+from tensorflow.python.ops import numpy_ops as np
+
 
 flags.DEFINE_integer('train_size', 1000,
                      'Dataset size to use for training.')
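Note on the import swap above: TensorFlow ships a NumPy-compatible API whose public entry point is tf.experimental.numpy (this diff imports the internal module tensorflow.python.ops.numpy_ops directly), so most `np.*` call sites written against jax.numpy can stay unchanged. A minimal sketch of the idea, using the public alias:

    import tensorflow.experimental.numpy as np  # public counterpart of numpy_ops

    x = np.asarray([[1., 2.], [3., 4.]])
    print(np.dot(x, x))         # same call shape as jax.numpy.dot
    print(np.reshape(x, (4,)))  # same call shape as jax.numpy.reshape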
@@ -50,7 +52,6 @@ def main(unused_argv):
       stax.Relu(),
       stax.Dense(1, 2., 0.05)
   )
-
   # Optionally, compute the kernel in batches, in parallel.
   kernel_fn = nt.batch(kernel_fn,
                        device_count=0,
@@ -61,8 +62,6 @@
   predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                         y_train, diag_reg=1e-3)
   fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
-  fx_test_nngp.block_until_ready()
-  fx_test_ntk.block_until_ready()
 
   duration = time.time() - start
   print('Kernel construction and inference done in %s seconds.' % duration)
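Why block_until_ready() goes away: under JAX, array operations are dispatched asynchronously, so the example had to block before reading the clock. TF eager tensors have no block_until_ready method; a hedged sketch of an equivalent synchronization point (forcing the result to host memory) under that assumption:

    import time
    import tensorflow as tf

    start = time.time()
    result = tf.matmul(tf.random.normal((512, 512)),
                       tf.random.normal((512, 512)))
    _ = result.numpy()  # materialize on the host before stopping the timer
    print('done in %s seconds.' % (time.time() - start))

As written, the PR simply drops the calls, so the reported duration relies on eager execution having finished the work by the time predict_fn returns.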
3 changes: 1 addition & 2 deletions neural_tangents/predict.py
@@ -1007,7 +1007,6 @@ def _get_fns_in_eigenbasis(k_train_train: np.ndarray,
   k_train_train = utils.make_2d(k_train_train)
   k_train_train = _add_diagonal_regularizer(k_train_train, diag_reg,
                                             diag_reg_absolute_scale)
-  print("k_train_train: {}".format(k_train_train))
   evals, evecs = tf.linalg.eigh(k_train_train)
   evals, evecs = np.asarray(evals), np.asarray(evecs)
 
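The pattern kept above (tf.linalg.eigh followed by np.asarray) exists because tf.linalg.eigh returns plain tf.Tensor values, while the rest of the function works with tf-numpy ndarrays. A small self-contained sketch:

    import tensorflow as tf
    import tensorflow.experimental.numpy as np

    k = tf.constant([[2., 1.], [1., 2.]])
    evals, evecs = tf.linalg.eigh(k)                     # tf.Tensor outputs
    evals, evecs = np.asarray(evals), np.asarray(evecs)  # back to ndarrays
    print(evals)  # eigenvalues [1., 3.] in ascending order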
@@ -1050,7 +1049,7 @@ def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
     b = np.moveaxis(b, b_axes, last_b_axes)
     b = b.reshape((A.shape[1], -1))
 
-    x = np.linalg_ops.cholesky_solve(C, b)
+    x = np.asarray(np.linalg_ops.cholesky_solve(C, b))
     x = x.reshape(x_shape)
     return x
 
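Same conversion issue in cho_solve: the Cholesky solve returns a tf.Tensor, and the next line calls x.reshape(x_shape), an ndarray method, hence the np.asarray wrapper. A sketch of the fix using the public tf.linalg API (the diff's np.linalg_ops alias appears to be specific to this port):

    import tensorflow as tf
    import tensorflow.experimental.numpy as np

    A = tf.constant([[4., 2.], [2., 3.]])
    b = tf.constant([[1.], [2.]])
    C = tf.linalg.cholesky(A)                       # lower-triangular factor
    x = np.asarray(tf.linalg.cholesky_solve(C, b))  # solves A @ x = b
    print(np.reshape(x, (2,)))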
7 changes: 3 additions & 4 deletions neural_tangents/stax.py
@@ -2534,7 +2534,6 @@ def get_x_cov_mask(x):
       x = x.astype(np.float64)
 
     if diagonal_batch:
-      print("diagonal_spatial: {}".format(diagonal_spatial))
       cov = _cov_diag_batch(x, diagonal_spatial, batch_axis, channel_axis)
     else:
       cov = _cov(x, x, diagonal_spatial, batch_axis, channel_axis)
@@ -2571,11 +2570,11 @@ def get_x_cov_mask(x):

 def _propagate_shape(init_fn: InitFn, shape: Shapes) -> Shapes:
   """Statically, abstractly, evaluate the init_fn to get shape information."""
-  akey = tf.TensorSpec((2,), np.uint32)
+  akey = tf.TensorSpec((2,), np.int32)
   closed_init_fn = functools.partial(init_fn, input_shape=shape)
   _, in_tree = tree_flatten(((akey,), {}))
   fun, out_tree = flatten_fun(lu.wrap_init(closed_init_fn), in_tree)
-  out = eval_on_shapes(fun.call_wrapped)(akey)
+  out = eval_on_shapes(fun.call_wrapped, allow_static_outputs=True)(akey)
   out_shape = tree_unflatten(out_tree(), out)[0]
   return out_shape
 
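On the two changes in _propagate_shape: the key dtype moves from np.uint32 to np.int32 (presumably because the TF backend here is stricter about unsigned dtypes), and eval_on_shapes is asked to allow static outputs. eval_on_shapes comes from the tf-numpy extensions this port builds on; the underlying idea, traced with plain tf.function as a hedged stand-in (init_fn below is hypothetical):

    import tensorflow as tf

    def init_fn(key):
        # stand-in for a layer init_fn; returns something shape-dependent
        return tf.ones((2, 16)) * tf.cast(key[0], tf.float32)

    spec = tf.TensorSpec((2,), tf.int32)  # matches the dtype change above
    concrete = tf.function(init_fn).get_concrete_function(spec)
    print(concrete.output_shapes)         # shapes recovered without running init_fn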
@@ -2826,7 +2825,7 @@ def _get_diagonal(

   batch_ndim = 1 if diagonal_batch else 2
   start_axis = 2 - batch_ndim
-  end_axis = batch_ndim if diagonal_spatial else cov.ndim
+  end_axis = batch_ndim if diagonal_spatial else len(cov.shape)
   cov = utils.unzip_axes(cov, start_axis, end_axis)
   return utils.diagonal_between(cov, start_axis, end_axis)
 
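The end_axis change trades cov.ndim for len(cov.shape): len(x.shape) is valid for NumPy arrays, tf-numpy ndarrays, and plain tf.Tensor values alike, whereas .ndim is not guaranteed on every tensor type this code may receive. A quick check:

    import numpy as onp
    import tensorflow as tf

    for x in (onp.zeros((3, 4, 5)), tf.zeros((3, 4, 5))):
        print(len(x.shape))  # 3 in both cases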
1 change: 0 additions & 1 deletion neural_tangents/utils/batch.py
@@ -416,7 +416,6 @@ def serial_fn(x1_or_kernel: Union[np.ndarray, Kernel],
   if isinstance(x1_or_kernel, np.ndarray):
     return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
   elif isinstance(x1_or_kernel, onp.ndarray):
-    print("args: , kwargs: ", *args, **kwargs)
     return serial_fn_x1(np.asarray(x1_or_kernel), x2, *args, **kwargs)
   elif isinstance(x1_or_kernel, Kernel):
     if x2 is not None:
4 changes: 0 additions & 4 deletions neural_tangents/utils/dataclasses.py
@@ -77,10 +77,6 @@ def clz_from_iterable(meta, data):
     kwargs = dict(meta_args + data_args)
     return data_clz(**kwargs)
 
-  jax.tree_util.register_pytree_node(data_clz,
-                                     iterate_clz,
-                                     clz_from_iterable)
-
   def replace(self: data_clz, **kwargs) -> data_clz:
     return dataclasses.replace(self, **kwargs)
 
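The deletion above drops the jax.tree_util pytree registration, since this port no longer depends on JAX; the decorated classes keep the replace helper defined just below it. An illustrative sketch of what that helper does, with a hypothetical Point dataclass:

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class Point:
        x: float
        y: float

    p = dataclasses.replace(Point(1.0, 2.0), y=5.0)  # updated immutable copy
    print(p)  # Point(x=1.0, y=5.0)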
5 changes: 2 additions & 3 deletions neural_tangents/utils/kernel.py
@@ -17,8 +17,7 @@

 import operator as op
 from typing import Dict, Tuple, Optional, Callable, Any
-import dataclasses
-from neural_tangents.utils import utils
+from neural_tangents.utils import utils, dataclasses
 
 import tensorflow as tf
 from tensorflow.python.ops import numpy_ops as np
@@ -124,7 +123,7 @@ class Kernel:
   shape2: Tuple[int, ...]
 
   batch_axis: int
-  channel_axis: int
+  channel_axis: int
 
   mask1: Optional[np.ndarray] = None
   mask2: Optional[np.ndarray] = None
8 changes: 4 additions & 4 deletions tests/lax_test.py
@@ -76,10 +76,10 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims):
           for dimension_numbers, perms in [
               (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0]))
           ]])
-  def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides,
-                             padding, lhs_dilation, rhs_dilation,
-                             feature_group_count, batch_group_count,
-                             dimension_numbers, perms):
+  def test_tf_conv_general_dilated(self, lhs_shape, rhs_shape, strides,
+                                   padding, lhs_dilation, rhs_dilation,
+                                   feature_group_count, batch_group_count,
+                                   dimension_numbers, perms):
     tf.print("dimension_numbers: {}".format(dimension_numbers), output_stream=sys.stdout)
     lhs_perm, rhs_perm = perms  # permute to compatible shapes
 