diff --git a/.travis.yml b/.travis.yml index 9ed883fc..451e5463 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,3 +17,4 @@ script: # errors in stax tests and infinite fcn tests. - python tests/function_space_test.py - python tests/weight_space_test.py + - python tests/infinite_fcn_test.py diff --git a/examples/infinite_fcn.py b/examples/infinite_fcn.py index ffe327f1..25444d20 100644 --- a/examples/infinite_fcn.py +++ b/examples/infinite_fcn.py @@ -20,12 +20,14 @@ import time from absl import app from absl import flags -import jax.numpy as np import neural_tangents as nt from neural_tangents import stax from examples import datasets from examples import util +import tensorflow as tf +from tensorflow.python.ops import numpy_ops as np + flags.DEFINE_integer('train_size', 1000, 'Dataset size to use for training.') @@ -50,7 +52,6 @@ def main(unused_argv): stax.Relu(), stax.Dense(1, 2., 0.05) ) - # Optionally, compute the kernel in batches, in parallel. kernel_fn = nt.batch(kernel_fn, device_count=0, @@ -61,8 +62,6 @@ def main(unused_argv): predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, diag_reg=1e-3) fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test) - fx_test_nngp.block_until_ready() - fx_test_ntk.block_until_ready() duration = time.time() - start print('Kernel construction and inference done in %s seconds.' % duration) diff --git a/neural_tangents/predict.py b/neural_tangents/predict.py index 1c0ac8d2..be3da6a8 100644 --- a/neural_tangents/predict.py +++ b/neural_tangents/predict.py @@ -1007,7 +1007,6 @@ def _get_fns_in_eigenbasis(k_train_train: np.ndarray, k_train_train = utils.make_2d(k_train_train) k_train_train = _add_diagonal_regularizer(k_train_train, diag_reg, diag_reg_absolute_scale) - print("k_train_train: {}".format(k_train_train)) evals, evecs = tf.linalg.eigh(k_train_train) evals, evecs = np.asarray(evals), np.asarray(evecs) @@ -1050,7 +1049,7 @@ def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray: b = np.moveaxis(b, b_axes, last_b_axes) b = b.reshape((A.shape[1], -1)) - x = np.linalg_ops.cholesky_solve(C, b) + x = np.asarray(np.linalg_ops.cholesky_solve(C, b)) x = x.reshape(x_shape) return x diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index e97a9b03..66c0a517 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -2534,7 +2534,6 @@ def get_x_cov_mask(x): x = x.astype(np.float64) if diagonal_batch: - print("diagonal_spatial: {}".format(diagonal_spatial)) cov = _cov_diag_batch(x, diagonal_spatial, batch_axis, channel_axis) else: cov = _cov(x, x, diagonal_spatial, batch_axis, channel_axis) @@ -2571,11 +2570,11 @@ def get_x_cov_mask(x): def _propagate_shape(init_fn: InitFn, shape: Shapes) -> Shapes: """Statically, abstractly, evaluate the init_fn to get shape information.""" - akey = tf.TensorSpec((2,), np.uint32) + akey = tf.TensorSpec((2,), np.int32) closed_init_fn = functools.partial(init_fn, input_shape=shape) _, in_tree = tree_flatten(((akey,), {})) fun, out_tree = flatten_fun(lu.wrap_init(closed_init_fn), in_tree) - out = eval_on_shapes(fun.call_wrapped)(akey) + out = eval_on_shapes(fun.call_wrapped, allow_static_outputs=True)(akey) out_shape = tree_unflatten(out_tree(), out)[0] return out_shape @@ -2826,7 +2825,7 @@ def _get_diagonal( batch_ndim = 1 if diagonal_batch else 2 start_axis = 2 - batch_ndim - end_axis = batch_ndim if diagonal_spatial else cov.ndim + end_axis = batch_ndim if diagonal_spatial else len(cov.shape) cov = utils.unzip_axes(cov, start_axis, end_axis) return 
utils.diagonal_between(cov, start_axis, end_axis) diff --git a/neural_tangents/utils/batch.py b/neural_tangents/utils/batch.py index bfbff58b..a764acd8 100644 --- a/neural_tangents/utils/batch.py +++ b/neural_tangents/utils/batch.py @@ -416,7 +416,6 @@ def serial_fn(x1_or_kernel: Union[np.ndarray, Kernel], if isinstance(x1_or_kernel, np.ndarray): return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs) elif isinstance(x1_or_kernel, onp.ndarray): - print("args: , kwargs: ", *args, **kwargs) return serial_fn_x1(np.asarray(x1_or_kernel), x2, *args, **kwargs) elif isinstance(x1_or_kernel, Kernel): if x2 is not None: diff --git a/neural_tangents/utils/dataclasses.py b/neural_tangents/utils/dataclasses.py index 752fa154..8bf0fe89 100644 --- a/neural_tangents/utils/dataclasses.py +++ b/neural_tangents/utils/dataclasses.py @@ -77,10 +77,6 @@ def clz_from_iterable(meta, data): kwargs = dict(meta_args + data_args) return data_clz(**kwargs) - jax.tree_util.register_pytree_node(data_clz, - iterate_clz, - clz_from_iterable) - def replace(self: data_clz, **kwargs) -> data_clz: return dataclasses.replace(self, **kwargs) diff --git a/neural_tangents/utils/kernel.py b/neural_tangents/utils/kernel.py index 22e76c10..db2bac73 100644 --- a/neural_tangents/utils/kernel.py +++ b/neural_tangents/utils/kernel.py @@ -17,8 +17,7 @@ import operator as op from typing import Dict, Tuple, Optional, Callable, Any -import dataclasses -from neural_tangents.utils import utils +from neural_tangents.utils import utils, dataclasses import tensorflow as tf from tensorflow.python.ops import numpy_ops as np @@ -124,7 +123,7 @@ class Kernel: shape2: Tuple[int, ...] batch_axis: int - channel_axis: int + channel_axis: int mask1: Optional[np.ndarray] = None mask2: Optional[np.ndarray] = None diff --git a/tests/lax_test.py b/tests/lax_test.py index f4501782..a573db3d 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -76,10 +76,10 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims): for dimension_numbers, perms in [ (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])) ]]) - def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides, - padding, lhs_dilation, rhs_dilation, - feature_group_count, batch_group_count, - dimension_numbers, perms): + def test_tf_conv_general_dilated(self, lhs_shape, rhs_shape, strides, + padding, lhs_dilation, rhs_dilation, + feature_group_count, batch_group_count, + dimension_numbers, perms): tf.print("dimension_numbers: {}".format(dimension_numbers), output_stream=sys.stdout) lhs_perm, rhs_perm = perms # permute to compatible shapes diff --git a/tf_helpers/extensions.py b/tf_helpers/extensions.py index b875b509..9c09e2a4 100644 --- a/tf_helpers/extensions.py +++ b/tf_helpers/extensions.py @@ -22,6 +22,7 @@ import bisect import contextlib import string +import sys import threading import numpy as np import six @@ -37,6 +38,9 @@ tf.int64, tf.int32, tf.int16, tf.int8, tf.uint8, tf.uint16, tf.uint32, tf.uint64 ] +_tf_nn_APIs = {1: [tf.nn.conv1d, tf.nn.conv1d_transpose], + 2: [tf.nn.conv2d, tf.nn.conv2d_transpose], + 3: [tf.nn.conv3d, tf.nn.conv3d_transpose]} def most_precise_int_dtype(x): @@ -272,7 +276,7 @@ def _f(params, *args): def _record_result_type(recorder, f): def wrapper(*args, **kwargs): res = f(*args, **kwargs) - recorder(res) + res = recorder(res) return res return wrapper @@ -332,6 +336,7 @@ def _tf_f(*args, **kwargs): # Workaround b/121383831 def recorder(res): _orig_result_is_list.val = isinstance(res, list) + return res f_ = _record_result_type(recorder, f) np_out = 
tf.xla.experimental.compile(lambda: f_(*np_args, **kwargs)) # Workaround b/121383831 @@ -380,16 +385,34 @@ def eval_on_shapes(f, static_argnums=(), allow_static_outputs=False): """ if allow_static_outputs: def recorder(res): + def is_tensor_like(x): + return isinstance(x, (tf_np.ndarray, tf.Tensor)) _python_outputs.val = tf.nest.map_structure( - lambda x: None if isinstance(x, (tf_np.ndarray, tf.Tensor)) else x, + lambda x: None if is_tensor_like(x) else x, res) + # Set non-tensor outputs to None to avoid tf.function calling + # tf.convert_to_tensor on them. + res = tf.nest.map_structure( + lambda x: None if not is_tensor_like(x) else x, + res) + return res f = _record_result_type(recorder, f) + # When `tf_f` below is called (via get_concrete_function) with the same + # arugments (after abstraction), the Python function `f` won't be run, so we + # need this python_outputs_map to retrieve the Python outputs we've seen + # before that correspond the arguments. We could choose to record directly + # in this map instead of _python_outputs, but that requires calculating the + # key, which will duplicate the abstracting/hashing computation below. One + # can view _python_outputs as an extra output of `f` (that bypasses + # tf.function). python_outputs_map = {} + map_lock = threading.Lock() # TODO(wangpeng): tf.function could add a knob to turn off materializing the # graph, so that we don't waste computation and memory when we just want # shape inference. tf_f = jit(f, static_argnums=static_argnums).tf_function + print("f's name: {}".format(f.__name__)) # pylint: disable=missing-docstring def f_return(*args): @@ -415,25 +438,28 @@ def to_tensor_spec(x): new_args.append(tf.nest.map_structure(abstractify, arg)) if allow_static_outputs: - def _hash(args): - # TODO(wangpeng): This hash loses some structural info. Improve it. - return hash(tuple(tf.nest.flatten(args))) _python_outputs.val = None - res = tf_f.get_concrete_function(*new_args).structured_outputs + print("new args: {}".format(new_args)) + cfun = tf_f.get_concrete_function(*new_args) + res = cfun.structured_outputs res = tf.nest.map_structure(to_tensor_spec, res) if allow_static_outputs: - key = _hash(new_args) - if python_outputs_map.get(key) is None: - python_outputs_map[key] = _python_outputs.val + key = id(cfun) + map_lock.acquire() + py_values = python_outputs_map.get(key) + if py_values is None: + py_values = _python_outputs.val + python_outputs_map[key] = py_values + map_lock.release() # We can also call tf.get_static_value on structured_outputs to retrieve - # the Python values, but since we'll need to use _python_outputs to store + # the Python values, but since we'll need to use _python_outputs to record # "which outputs are static?" anyway, we choose to directly store the # Python values in _python_outputs. res = tf.nest.map_structure( lambda x, python_value: x if python_value is None else python_value, - res, python_outputs_map[key]) + res, py_values) return res @@ -600,6 +626,138 @@ def tf_dot_general(lhs, rhs, dimension_numbers, precision=None): return tf.einsum(equation, lhs, rhs) +def _conv_general_param_type_converter(window_strides, lhs_dilation, + rhs_dilation, dim): + """Convert strides, lhs_dilation, rhs_dilation to match TF convention. 
+ + For example, + in the 3D case, if lhs_dilation = 2, it is converted to [2, 2, 2] + if lhs_dilation = (2, 2, 2), it is likewise converted to [2, 2, 2] + + Args: + window_strides: window_strides to be converted + lhs_dilation: lhs_dilation to be converted + rhs_dilation: rhs_dilation to be converted + dim: dim to be converted + + Returns: + The updated window_strides, lhs_dilation and rhs_dilation + """ + def _as_list_of_size(item, size): + if item is None: + return None + return [item] * size if isinstance(item, int) else list(item) + return (_as_list_of_size(window_strides, dim), + _as_list_of_size(lhs_dilation, dim), + _as_list_of_size(rhs_dilation, dim)) + + +# pylint: disable=g-bad-todo +# TODO(DarrenZhang01): Expand the test cases of general convolution and fix +# the corresponding bugs. +# TODO(DarrenZhang01): Support feature_group_count, batch_group_count and +# precision, and allow lhs_dilation and rhs_dilation to happen at the same time. +# pylint: enable=g-bad-todo +def tf_conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, + lhs_dilation=None, rhs_dilation=None, + dimension_numbers=None, feature_group_count=1, + batch_group_count=1, precision=None): + """A general conv API for TensorFlow. + + See the corresponding JAX version: + https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html + + Args: + lhs: a rank n+2 dimensional input array. + rhs: a rank n+2 dimensional array of kernel weights. + window_strides: a sequence of n integers, representing the inter-window + strides. + padding: either the string 'SAME', the string 'VALID', or a sequence of n + (low, high) integer pairs that give the padding to apply before and + after each spatial dimension. + output_shape: the output shape of the convolution (only required for + transpose convolution). + lhs_dilation: None, or a sequence of n integers, giving the dilation factor + to apply in each spatial dimension of lhs. LHS dilation is + also known as transposed convolution. + rhs_dilation: None, or a sequence of n integers, giving the dilation factor + to apply in each spatial dimension of rhs. RHS dilation is + also known as atrous convolution. + dimension_numbers: either None, a ConvDimensionNumbers object, or a 3-tuple + (lhs_spec, rhs_spec, out_spec), where each element is a + string of length n+2. + feature_group_count: integer, default 1. Changing this is currently not + supported. + batch_group_count: integer, default 1. Changing this is currently not + supported. + precision: Optional. Either None, which means the default precision for the + backend, or a Precision enum value. + + Returns: + A TF NumPy array that contains the convolution result. + """ + dim = None + lhs_spec, rhs_spec, out_spec = dimension_numbers + if lhs_spec != out_spec: + raise ValueError("Current implementation requires the `data_format` of the " + "inputs and outputs to be the same.") + if len(lhs_spec) >= 6: + raise ValueError("Current implementation does not support 4 or higher " + "dimensional convolution, but got: ", len(lhs_spec) - 2) + dim = len(lhs_spec) - 2 + if lhs_dilation and rhs_dilation: + if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim: + lhs_dilation, rhs_dilation = None, None + else: + raise ValueError("Current implementation does not support " + "deconvolution and dilation performed at the same " + "time, but got lhs_dilation: {}, rhs_dilation: {}" + .format(lhs_dilation, rhs_dilation)) + if padding not in ["SAME", "VALID"]: + raise ValueError("Current implementation requires the padding parameter " + "to be either 'VALID' or 'SAME', but got: ", padding) + if batch_group_count != 1 or feature_group_count != 1: + raise NotImplementedError("batch_group_count and feature_group_count " + "other than 1 are currently not supported, but" + " got feature_group_count: {}, batch_group_count" + ": {}".format(feature_group_count, + batch_group_count)) + if precision is not None: + raise NotImplementedError("precision other than `None` is currently not " + "supported, but got: {}".format(precision)) + # Convert params from int/Sequence[int] to list of ints. + strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( + window_strides, lhs_dilation, rhs_dilation, dim + ) + # Preprocess the shapes + dim_maps = {} + if isinstance(lhs_spec, str): + dim_maps["I"] = list(rhs_spec).index("I") + dim_maps["O"] = list(rhs_spec).index("O") + dim_maps["N"] = list(lhs_spec).index("N") + dim_maps["C"] = list(lhs_spec).index("C") + else: + dim_maps["I"] = rhs_spec[1] + dim_maps["O"] = rhs_spec[0] + dim_maps["N"] = lhs_spec[0] + dim_maps["C"] = lhs_spec[1] + + lhs = tf_np.moveaxis(lhs, (dim_maps["N"], dim_maps["C"]), (0, dim + 1)) + # Adjust the filters, put the dimension 'I' and 'O' at last. + rhs = tf_np.moveaxis(rhs, (dim_maps["O"], dim_maps["I"]), (dim + 1, dim)) + spatial_dim_maps = {1: "W", 2: "HW", 3: "DHW"} + data_format = "N" + spatial_dim_maps[dim] + "C" + + if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): + output = _tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, + rhs_dilation) + else: + output = _tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, + padding, data_format, lhs_dilation) + output = tf_np.moveaxis(output, (0, dim + 1), (dim_maps["N"], dim_maps["C"])) + return output + + + def conv(inp, fltr, window_strides, @@ -1267,6 +1425,7 @@ def _get_pmap_impl(f, devices, has_tpu): # Workaround b/121383831 def recorder(res): _orig_result_is_list.val = isinstance(res, list) + return res f = _record_result_type(recorder, f) def tf_f(*tf_args): @@ -1445,17 +1604,113 @@ def accelerators(devices=None): return tpu_devices(devices) or gpu_devices(devices) -# TODO(agarwal): support axes arguments.
-def vmap(f): - """Returns a function that maps `f` over first dimension of inputs.""" +def _tree_broadcast(to, s): + """Broadcasts `s` to the nested structure `to`.""" + if not isinstance(to, (list, tuple, dict)): + if not isinstance(s, (int, type(None))): + raise ValueError + return s + if isinstance(s, (int, type(None))): + return tf.nest.map_structure(lambda x: s, to) + if isinstance(to, (list, tuple)): + if len(to) != len(s): + raise ValueError + new_s = [_tree_broadcast(x, y) for x, y in zip(to, s)] + if isinstance(to, tuple): + new_s = tuple(new_s) + return new_s + elif isinstance(to, dict): + return {k: _tree_broadcast(to[k], s[k]) for k in to.keys()} + else: + raise TypeError("Unsupported type %s" % type(to)) - def _f(*args): - tf_args = tf.nest.map_structure(lambda x: tf_np.asarray(x).data, args) - def tf_f(x): - return f(*x) +def vmap(f, in_axes=0, out_axes=0): + """Returns a function that maps `f` over first dimension of inputs.""" + in_axes_flat = tf.nest.flatten(in_axes) + if not all(isinstance(l, (type(None), int)) + for l in in_axes_flat): + raise TypeError( + "vmap in_axes must be an int, None, or (nested) container with " + "those types as leaves, but got {}.".format(in_axes)) + if all(isinstance(l, type(None)) for l in in_axes_flat): + raise ValueError("vmap must have at least one non-None value in in_axes") + + out_axes_flat = tf.nest.flatten(out_axes) + if not all(isinstance(l, (type(None), int)) + for l in out_axes_flat): + raise TypeError( + "vmap out_axes must be an int, None, or (nested) container with " + "those types as leaves, but got {}.".format(out_axes)) + def _f(*args): + flat_args = tf.nest.flatten(args) + try: + f_in_axes = _tree_broadcast(args, in_axes) + except ValueError: + six.reraise( + ValueError, + ValueError( + "vmap in_axes specification must be a tree prefix of the " + r"corresponding value, got specification %s for value tree %s" % ( + in_axes, args)), + sys.exc_info()[2]) + f_in_axes_flat = tf.nest.flatten(f_in_axes) + + def tf_f(tf_args): + """Function passed to tf.vectorized_map call.""" + # Note that unbatched arguments are not passed to tf_f. Here we fill thos + # arguments back before calling `f`. + tf_flat_args = [] + j = 0 + for arg, axis in zip(flat_args, f_in_axes_flat): + if axis is None: + tf_flat_args.append(arg) + else: + tf_flat_args.append(tf_args[j]) + j += 1 + unbatched_args = tf.nest.pack_sequence_as(args, tf_flat_args) + return f(*unbatched_args) + + # Constructs arguments to pass to `tf_f`. + # Unbatch arguments are skipped. Arguments with non-zero axis are + # transposed. + tf_args = [] + for arg, axis in zip(flat_args, f_in_axes_flat): + if axis is None: + continue + arg = tf_np.asarray(arg) + if axis != 0: + arg = tf_np.moveaxis(arg, axis, 0) + tf_args.append(arg) + # TODO(agarwal): consider creating a tf.function outside of _f and reusing + # that to avoid overheads of re-vectorizing the code when running eagerly. outputs = tf.vectorized_map(tf_f, tf_args) - return tf.nest.map_structure(tf_np.asarray, outputs) + try: + f_out_axes = _tree_broadcast(outputs, out_axes) + except ValueError: + six.reraise( + ValueError, + ValueError( + "vmap out_axes specification must be a tree prefix of the " + r"corresponding value, got specification %s for value tree %s" % ( + out_axes, outputs)), + sys.exc_info()[2]) + + def map_output(x, axis): + """Maps output of tf.vectorized_map to the final output.""" + x = tf_np.asarray(x) + if axis is None: + # Note that `tf.vectorized_map always batches the outputs. + # Here we unbatch it again. 
+ return x[0, ...] + elif axis == 0: + return x + else: + # Need to transpose the output. + return tf_np.moveaxis(x, 0, axis) + new_outputs = [map_output(output, axis) for output, axis in zip( + tf.nest.flatten(outputs), tf.nest.flatten(f_out_axes))] + return tf.nest.pack_sequence_as(outputs, new_outputs) return _f diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index 0ac4f45b..e676d1df 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -14,6 +14,13 @@ # ============================================================================== +""" +This file contains some TF-based lax utilities. Some utilities, except +general convolution, general dot and reduce window and their dependencies, +are adpated from https://github.com/google/jax/blob/master/jax/lax/lax.py. +""" + + import builtins from typing import NamedTuple, Sequence import string @@ -24,15 +31,13 @@ from tf_helpers.extensions import tf_dot_general import sys -_max = builtins.max - def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1): """Compute the shape tuple of a conv given input shapes in canonical order.""" if isinstance(pads, str): pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads) if len(pads) != len(lhs_shape) - 2: - msg = "Wrong number of explicit pads for convolution: expected {}, got {}." + msg = 'Wrong number of explicit pads for convolution: expected {}, got {}.' raise TypeError(msg.format(len(lhs_shape) - 2, len(pads))) lhs_padded = onp.add(lhs_shape[2:], np.sum(np.array(pads).reshape(-1, 2), @@ -63,20 +68,20 @@ class ConvDimensionNumbers(NamedTuple): def conv_general_permutations(dimension_numbers): """Utility for convolution dimension permutations relative to Conv HLO.""" lhs_spec, rhs_spec, out_spec = dimension_numbers - lhs_char, rhs_char, out_char = charpairs = ("N", "C"), ("O", "I"), ("N", "C") + lhs_char, rhs_char, out_char = charpairs = ('N', 'C'), ('O', 'I'), ('N', 'C') for i, (a, b) in enumerate(charpairs): if not dimension_numbers[i].count(a) == dimension_numbers[i].count(b) == 1: - msg = ("convolution dimension_numbers[{}] must contain the characters " - "'{}' and '{}' exactly once, got {}.") + msg = ('convolution dimension_numbers[{}] must contain the characters ' + '`{}` and `{}` exactly once, got {}.') raise TypeError(msg.format(i, a, b, dimension_numbers[i])) if len(dimension_numbers[i]) != len(set(dimension_numbers[i])): - msg = ("convolution dimension_numbers[{}] cannot have duplicate " - "characters, got {}.") + msg = ('convolution dimension_numbers[{}] cannot have duplicate ' + 'characters, got {}.') raise TypeError(msg.format(i, dimension_numbers[i])) if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) == set(out_spec) - set(out_char)): - msg = ("convolution dimension_numbers elements must each have the same " - "set of spatial characters, got {}.") + msg = ('convolution dimension_numbers elements must each have the same ' + 'set of spatial characters, got {}.') raise TypeError(msg.format(dimension_numbers)) def getperm(spec, charpair): @@ -104,7 +109,7 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers): if isinstance(dimension_numbers, ConvDimensionNumbers): return dimension_numbers if len(lhs_shape) != len(rhs_shape): - msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}." + msg = 'convolution requires lhs and rhs ndim to be equal, got {} and {}.' 
raise TypeError(msg.format(len(lhs_shape), len(rhs_shape))) if dimension_numbers is None: @@ -112,13 +117,13 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers): return ConvDimensionNumbers(iota, iota, iota) elif isinstance(dimension_numbers, (list, tuple)): if len(dimension_numbers) != 3: - msg = "convolution dimension_numbers list/tuple must be length 3, got {}." + msg = 'convolution dimension_numbers list/tuple must be length 3, got {}.' raise TypeError(msg.format(len(dimension_numbers))) if not all(isinstance(elt, str) for elt in dimension_numbers): - msg = "convolution dimension_numbers elements must be strings, got {}." + msg = 'convolution dimension_numbers elements must be strings, got {}.' raise TypeError(msg.format(tuple(map(type, dimension_numbers)))) - msg = ("convolution dimension_numbers[{}] must have len equal to the ndim " - "of lhs and rhs, got {} for lhs and rhs shapes {} and {}.") + msg = ('convolution dimension_numbers[{}] must have len equal to the ndim ' + 'of lhs and rhs, got {} for lhs and rhs shapes {} and {}.') for i, elt in enumerate(dimension_numbers): if len(elt) != len(lhs_shape): raise TypeError(msg.format(i, len(elt), lhs_shape, rhs_shape)) @@ -126,17 +131,17 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers): lhs_spec, rhs_spec, out_spec = conv_general_permutations(dimension_numbers) return ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) else: - msg = "convolution dimension_numbers must be tuple/list or None, got {}." + msg = 'convolution dimension_numbers must be tuple/list or None, got {}.' raise TypeError(msg.format(type(dimension_numbers))) def padtype_to_pads(in_shape, window_shape, window_strides, padding): - if padding == "SAME": + if padding == 'SAME': out_shape = _ceil_divide(in_shape, window_strides) pad_sizes = np.maximum(0, (out_shape - 1) * window_strides + window_shape - in_shape) return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] - elif padding == "VALID": + elif padding == 'VALID': return [(0, 0)] * len(in_shape) @@ -272,12 +277,12 @@ def reduce_window(inputs, init_value, reducer, window_dimensions, strides, # TensorFlow pool dimensionality checker. inputs = np.expand_dims(inputs, axis=(0, inputs.ndim)) if reducer not in [np.max, np.add]: - raise TypeError("Only max pooling and average/sum pooling are supported.") + raise TypeError('Only max pooling and average/sum pooling are supported.') # Note that there is no need to send in the parameter data format since the - # input is already of default data format - "N...C". The adjustments of the + # input is already of default data format - 'N...C'. The adjustments of the # input shape is already finished in apply_fun of Pooling in stax. 
- pooling = "AVG" if reducer == np.add else "MAX" + pooling = 'AVG' if reducer == np.add else 'MAX' output = np.asarray(nn.pool(inputs, window_dimensions, pooling, strides, padding)) return np.squeeze(output, axis=(0, output.ndim - 1)) * np.prod(window_dimensions) @@ -294,23 +299,23 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_di dim = None lhs_spec, rhs_spec, out_spec = dimension_numbers if lhs_spec != out_spec: - raise TypeError("Current implementation requires the `data_format` of the " - "inputs and outputs to be the same.") + raise TypeError('Current implementation requires the `data_format` of the ' + 'inputs and outputs to be the same.') if len(lhs_spec) >= 6: - raise TypeError("Current implmentation does not support 4 or higher" - "dimensional convolution, but got: ", len(lhs_spec) - 2) + raise TypeError('Current implementation does not support 4 or higher ' + 'dimensional convolution, but got: ', len(lhs_spec) - 2) dim = len(lhs_spec) - 2 if lhs_dilation and rhs_dilation: if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim: lhs_dilation, rhs_dilation = None, None else: - raise TypeError("Current implementation does not support that deconvolution" - "and dilation to be performed at the same time, but got" - " lhs_dilation: {}, rhs_dilation: {}".format(lhs_dilation, - rhs_dilation)) + raise TypeError('Current implementation does not support ' + 'deconvolution and dilation performed at the same ' + 'time, but got lhs_dilation: {}, rhs_dilation: {}'.format( + lhs_dilation, rhs_dilation)) - if padding not in ["SAME", "VALID"]: - raise TypeError("Current implementation requires the padding parameter" - "to be either 'VALID' or 'SAME', but got: ", padding) + if padding not in ['SAME', 'VALID']: + raise TypeError('Current implementation requires the padding parameter ' + 'to be either `VALID` or `SAME`, but got: ', padding) # Convert params from int/Sequence[int] to list of ints. strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( window_strides, lhs_dilation, rhs_dilation ) # Preprocess the shapes dim_maps = {} if isinstance(lhs_spec, str): dim_maps['I'] = list(rhs_spec).index('I') dim_maps['O'] = list(rhs_spec).index('O') dim_maps['N'] = list(lhs_spec).index('N') dim_maps['C'] = list(lhs_spec).index('C') else: dim_maps['I'] = rhs_spec[1] dim_maps['O'] = rhs_spec[0] dim_maps['N'] = lhs_spec[0] dim_maps['C'] = lhs_spec[1] lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1)) # Adjust the filters, put the dimension 'I' and 'O' at last.
rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim)) - spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"} + spatial_dim_maps = {1: 'W', 2: 'HW', 3: 'DHW'} data_format = 'N' + spatial_dim_maps[dim] + 'C' tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose], 2: [nn.conv2d, nn.conv2d_transpose], @@ -358,7 +363,7 @@ def _conv_transpose_padding(k, s, padding): else: pad_a = int(np.ceil(pad_len / 2)) elif padding == 'VALID': - pad_len = k + s - 2 + _max(k - s, 0) + pad_len = k + s - 2 + max(k - s, 0) pad_a = k - 1 else: raise ValueError('Padding mode must be `SAME` or `VALID`.') @@ -390,9 +395,9 @@ def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides): """ output_shape = [lhs_shape[0]] for i in range(1, len(lhs_shape) - 1): - if padding == "SAME": + if padding == 'SAME': output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] + rhs_shape[i]) - if padding == "VALID": + if padding == 'VALID': output_shape.append((lhs_shape[i] - 1) * window_strides[i-1]) output_shape.append(lhs_shape[-1]) return tf.constant(output_shape) diff --git a/tf_helpers/stax.py b/tf_helpers/stax.py index 4a40ed99..9faa2aad 100644 --- a/tf_helpers/stax.py +++ b/tf_helpers/stax.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Stax is a small but flexible neural net specification library from scratch. +"""This is an adapted version of stax, based on TensorFlow, originated from -For an example of its use, see examples/resnet50.py. +- https://github.com/google/jax/blob/master/jax/experimental/stax.py """ import functools
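Usage sketch (illustrative, not part of the patch): what `_conv_general_param_type_converter` in tf_helpers/extensions.py normalizes, assuming the repository root is on PYTHONPATH; the values below are examples only.

# Sketch: int parameters are broadcast to lists of length `dim`,
# sequences are converted to lists, and None passes through unchanged.
from tf_helpers.extensions import _conv_general_param_type_converter

strides, lhs_dil, rhs_dil = _conv_general_param_type_converter(
    window_strides=2, lhs_dilation=None, rhs_dilation=(1, 2, 3), dim=3)
# strides == [2, 2, 2], lhs_dil is None, rhs_dil == [1, 2, 3]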
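Usage sketch (illustrative, not part of the patch): calling the new `tf_conv_general_dilated` with string dimension numbers, under the same PYTHONPATH assumption; shapes and names are illustrative.

# Sketch: a 2D NHWC/HWIO convolution on the non-transposed code path.
from tensorflow.python.ops import numpy_ops as tf_np
from tf_helpers.extensions import tf_conv_general_dilated

lhs = tf_np.ones((2, 8, 8, 3))   # NHWC input
rhs = tf_np.ones((3, 3, 3, 16))  # HWIO filter

out = tf_conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding="SAME",
    output_shape=None,           # only used on the transpose-convolution path
    lhs_dilation=None,
    rhs_dilation=None,
    dimension_numbers=("NHWC", "HWIO", "NHWC"))
# Expected output shape: (2, 8, 8, 16).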
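Usage sketch (illustrative, not part of the patch): the new `allow_static_outputs=True` behavior of `eval_on_shapes`, mirroring how `_propagate_shape` in neural_tangents/stax.py passes a `tf.TensorSpec` input; tensor outputs come back as `tf.TensorSpec`, while static Python outputs are passed through. The function `f` below is a hypothetical example.

import tensorflow as tf
from tensorflow.python.ops import numpy_ops as tf_np
from tf_helpers import extensions

def f(x):
  # One tensor output and one static (non-tensor) Python output.
  return tf_np.sum(x), "static-metadata"

shape_fn = extensions.eval_on_shapes(f, allow_static_outputs=True)
spec, meta = shape_fn(tf.TensorSpec((3, 4), tf.float32))
# `spec` is a tf.TensorSpec describing the tensor output;
# `meta` is the Python string, returned unchanged.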
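Usage sketch (illustrative, not part of the patch): the extended `vmap(f, in_axes, out_axes)`, which now supports per-argument axes (including None for unbatched arguments) on top of tf.vectorized_map; the helper `dot` and its shapes are illustrative.

from tensorflow.python.ops import numpy_ops as tf_np
from tf_helpers import extensions

def dot(a, b):
  # Contract the last axis of `a` with the first axis of `b`.
  return tf_np.dot(a, b)

a = tf_np.ones((8, 3))  # batched along axis 0
b = tf_np.ones((3, 5))  # shared across the batch

# Map over axis 0 of `a` only; `b` is passed through unbatched.
batched_dot = extensions.vmap(dot, in_axes=(0, None), out_axes=0)
out = batched_dot(a, b)  # expected shape: (8, 5)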