From d49ffb7b61755fb47ccd53dc782d5d352c6323b4 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Mon, 24 Aug 2020 15:41:24 -0400 Subject: [PATCH 01/28] Add the TensorFlow version of general dot operation to `tf_hlpers` folder. --- tf_helpers/tf_dot_general.py | 135 +++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tf_helpers/tf_dot_general.py diff --git a/tf_helpers/tf_dot_general.py b/tf_helpers/tf_dot_general.py new file mode 100644 index 00000000..10f61527 --- /dev/null +++ b/tf_helpers/tf_dot_general.py @@ -0,0 +1,135 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Construct an equivalent general dot operation as that in JAX - + + +Although there is an implementation in TF XLA, avoid directly using XLA when +possible. + +Zhibo Zhang, 2020.06.30 +""" + +import tensorflow as tf +from tensorflow.python.ops import numpy_ops as tf_np +import string + + +def _minus(a, b): + return [x for x in a if x not in b] + + +def compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, + lhs_batch, rhs_batch): + """ Compose the output string representation. + + e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik + aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik + + Args: + lhs_rep: A string representation for the left-hand side input array + rhs_rep: A string representation for the right-hand side input array + lhs_contraction: Sequence[int] (the contraction dimensions of lhs) + rhs_contraction: Sequence[int] (the contraction dimensions of rhs) + lhs_batch: Sequence[int] (the batch dimensions of lhs) + rhs_batch: Sequence[int] (the batch dimensions of rhs) + + Returns: + A string representation of the result array. + """ + output_rep = [] + for dim in lhs_batch: + output_rep.append(lhs_rep[dim]) + + for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction): + output_rep.append(lhs_rep[i]) + for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction): + output_rep.append(rhs_rep[i]) + return ''.join(output_rep) + + +def non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction): + """ Compute the non-batched matrix multiplication. + + If it is the general non-batched/single-batched matrix multiplication, + use the highly optimized kernel `tf.tensordot` to handle it. + + Args: + lhs: an array (the left-hand side matrix/vector to be multiplied) + rhs: an array (the right-hand side matrix/vector to be multiplied) + lhs_contraction: Sequence[int] (the contraction dimensions of lhs) + rhs_contraction: Sequence[int] (the contraction dimensions of rhs) + + Returns: + An array that contains the result. + """ + return tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction))) + + +def tf_dot_general(lhs, rhs, dimension_numbers): + """ The general dot operation for TensorFlow. 
+ + An equivalent general dot operation as that in JAX - + + Although there is an implementation in TF XLA, avoid directly using XLA when + possible. + + e.g., non-batched: ij,jk->ik + batched: ijk,ikl->ijl + + Args: + lhs: an array (the left-hand side matrix/vector to be multiplied) + rhs: an array (the right-hand side matrix/vector to be multiplied) + dimension_numbers: (Tuple[Tuple[Sequence[int], Sequence[int]], + Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of the form + ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) + + Returns: + An array that contains the result. + """ + char_list = list(string.ascii_lowercase) + char_list = char_list[8:] + char_list[:8] + lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape) + lhs_rep = char_list[:lhs_rank] + rhs_rep = char_list[lhs_rank:lhs_rank+rhs_rank] + contraction, batch = dimension_numbers + lhs_contraction, rhs_contraction = contraction + if len(lhs_contraction) != len(rhs_contraction): + raise ValueError("The input matrices are required to have the same number " + "of contraction dimensions, but got: lhs {}, rhs: {}".format( + len(lhs_contraction), len(rhs_contraction))) + lhs_batch, rhs_batch = batch + if len(lhs_batch) != len(rhs_batch): + raise ValueError("The input matrices are required to have the same number " + "of batch dimensions, but got: lhs {}, rhs: {}".format( + len(lhs_batch), len(rhs_batch))) + + if len(lhs_batch) == 0 and len(rhs_batch) == 0: + return non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction) + + if (lhs_rank == rhs_rank == 3 and lhs_batch == (0,) and rhs_batch == (0,) + and lhs_contraction == (2,) and rhs_contraction == (1,)): + return tf.linalg.matmul(lhs, rhs) + + for i in range(len(lhs_contraction)): + rhs_rep[rhs_contraction[i]] = lhs_rep[lhs_contraction[i]] + for i in range(len(lhs_batch)): + rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]] + + output_rep = compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, + rhs_contraction, lhs_batch, rhs_batch) + equation = ''.join(lhs_rep) + ',' + ''.join(rhs_rep) + "->" + output_rep + return tf.einsum(equation, lhs, rhs) From c57503fa68daf13db33ad78c18d5784417800eb9 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Mon, 24 Aug 2020 21:58:27 -0400 Subject: [PATCH 02/28] Add the TensorFlow version of some JAX lax utilities. --- tf_helpers/lax.py | 504 +++++++++++++++++++++++++++++++++++ tf_helpers/tf_dot_general.py | 135 ---------- 2 files changed, 504 insertions(+), 135 deletions(-) create mode 100644 tf_helpers/lax.py delete mode 100644 tf_helpers/tf_dot_general.py diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py new file mode 100644 index 00000000..efe2de39 --- /dev/null +++ b/tf_helpers/lax.py @@ -0,0 +1,504 @@ +# Copyright 2020 The Google/TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +This file contains TF equivalences for: + 1. 
`jax.lax.conv_general_shape_tuple` + 2. `jax.lax.conv_transpose_shape_tuple` + 3. `jax.lax.reduce_window_shape_tuple` +""" + +# from tensorflow.compiler.xla.python import xla_client +import builtins +from typing import (NamedTuple, Sequence) +import numpy as onp +import tensorflow as tf +import sys +from tf_conv_general import conv_general_dilated +from tf_reduce_window import reduce_window + +_max = builtins.max + + +#---------------------------------main APIs------------------------------------# + +def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1): + """Compute the shape tuple of a conv given input shapes in canonical order.""" + if isinstance(pads, str): + pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads) + if len(pads) != len(lhs_shape) - 2: + msg = "Wrong number of explicit pads for convolution: expected {}, got {}." + raise TypeError(msg.format(len(lhs_shape) - 2, len(pads))) + + lhs_padded = onp.add(lhs_shape[2:], onp.sum(onp.array(pads).reshape(-1, 2), + axis=1)) + out_space = onp.floor_divide( + onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1 + out_space = onp.maximum(0, out_space) + assert lhs_shape[0] % batch_group_count == 0 + out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0]) + return tuple(out_shape + tuple(out_space)) + + +class ConvDimensionNumbers(NamedTuple): + """Describes batch, spatial, and feature dimensions of a convolution. + Args: + lhs_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. + rhs_spec: a tuple of nonnegative integer dimension numbers containing + `(out feature dimension, in feature dimension, spatial dimensions...)`. + out_spec: a tuple of nonnegative integer dimension numbers containing + `(batch dimension, feature dimension, spatial dimensions...)`. + """ + lhs_spec: Sequence[int] + rhs_spec: Sequence[int] + out_spec: Sequence[int] + + +def conv_general_permutations(dimension_numbers): + """Utility for convolution dimension permutations relative to Conv HLO.""" + lhs_spec, rhs_spec, out_spec = dimension_numbers + lhs_char, rhs_char, out_char = charpairs = ("N", "C"), ("O", "I"), ("N", "C") + for i, (a, b) in enumerate(charpairs): + if not dimension_numbers[i].count(a) == dimension_numbers[i].count(b) == 1: + msg = ("convolution dimension_numbers[{}] must contain the characters " + "'{}' and '{}' exactly once, got {}.") + raise TypeError(msg.format(i, a, b, dimension_numbers[i])) + if len(dimension_numbers[i]) != len(set(dimension_numbers[i])): + msg = ("convolution dimension_numbers[{}] cannot have duplicate " + "characters, got {}.") + raise TypeError(msg.format(i, dimension_numbers[i])) + if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) == + set(out_spec) - set(out_char)): + msg = ("convolution dimension_numbers elements must each have the same " + "set of spatial characters, got {}.") + raise TypeError(msg.format(dimension_numbers)) + + def getperm(spec, charpair): + spatial = (i for i, c in enumerate(spec) if c not in charpair) + if spec is not rhs_spec: + spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i])) + return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial) + + lhs_perm, rhs_perm, out_perm = map(getperm, dimension_numbers, charpairs) + return lhs_perm, rhs_perm, out_perm + + +def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers): + """Converts convolution `dimension_numbers` to a `ConvDimensionNumbers`. 
+ Args: + lhs_shape: tuple of nonnegative integers, shape of the convolution input. + rhs_shape: tuple of nonnegative integers, shape of the convolution kernel. + dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers + object following the convolution dimension number specification format in + xla_client.py. + Returns: + A `ConvDimensionNumbers` object that represents `dimension_numbers` in the + canonical form used by lax functions. + """ + if isinstance(dimension_numbers, ConvDimensionNumbers): + return dimension_numbers + if len(lhs_shape) != len(rhs_shape): + msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}." + raise TypeError(msg.format(len(lhs_shape), len(rhs_shape))) + + if dimension_numbers is None: + iota = tuple(range(len(lhs_shape))) + return ConvDimensionNumbers(iota, iota, iota) + elif isinstance(dimension_numbers, (list, tuple)): + if len(dimension_numbers) != 3: + msg = "convolution dimension_numbers list/tuple must be length 3, got {}." + raise TypeError(msg.format(len(dimension_numbers))) + if not all(isinstance(elt, str) for elt in dimension_numbers): + msg = "convolution dimension_numbers elements must be strings, got {}." + raise TypeError(msg.format(tuple(map(type, dimension_numbers)))) + msg = ("convolution dimension_numbers[{}] must have len equal to the ndim " + "of lhs and rhs, got {} for lhs and rhs shapes {} and {}.") + for i, elt in enumerate(dimension_numbers): + if len(elt) != len(lhs_shape): + raise TypeError(msg.format(i, len(elt), lhs_shape, rhs_shape)) + + lhs_spec, rhs_spec, out_spec = conv_general_permutations(dimension_numbers) + return ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) + else: + msg = "convolution dimension_numbers must be tuple/list or None, got {}." + raise TypeError(msg.format(type(dimension_numbers))) + + +def padtype_to_pads(in_shape, window_shape, window_strides, padding): + if padding == "SAME": + out_shape = _ceil_divide(in_shape, window_strides) + pad_sizes = onp.maximum(0, (out_shape - 1) * window_strides + + window_shape - in_shape) + return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] + # elif padding == PaddingType.VALID: + elif padding == "VALID": + return [(0, 0)] * len(in_shape) + + +# helper function: 1. conv_general_permutations +# 2. conv_shape_tuple +def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, + dimension_numbers): + lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers) + lhs_trans = onp.take(lhs_shape, lhs_perm) + rhs_trans = onp.take(rhs_shape, rhs_perm) + out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding) + return tuple(onp.take(out_trans, onp.argsort(out_perm))) + + +# helper function: 1. conv_general_permutations +# 2. 
_conv_transpose_padding +def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, + dimension_numbers): + lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers) + lhs_trans = onp.take(lhs_shape, lhs_perm) + rhs_trans = onp.take(rhs_shape, rhs_perm) + if isinstance(padding, str): + padding = [_conv_transpose_padding(k, s, padding) + for k,s in zip(rhs_trans[2:], window_strides)] + padding = list(map(onp.sum, padding)) + unpad_out_space = [(i-1) * s - k + 2 + for i, k, s in zip(lhs_trans[2:], + rhs_trans[2:], + window_strides)] + out_space = onp.sum([unpad_out_space, padding], axis=0).tolist() + out_trans = tuple((lhs_trans[0], rhs_trans[0]) + tuple(out_space)) + return tuple(onp.take(out_trans, onp.argsort(out_perm))) + + +# helper function: 1. padtype_to_pads +def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, + padding): + window_dimensions = (1,) + window_dimensions + (1,) + window_strides = (1,) + window_strides + (1,) + pads = padtype_to_pads(operand_shape, window_dimensions, window_strides, padding) + operand_padded = onp.add(operand_shape, onp.add(*zip(*pads))) + t = onp.floor_divide( + onp.subtract(operand_padded, window_dimensions), window_strides) + 1 + return tuple(t) + + +# helper function: 1. conv_dimension_numbers +# 2. _conv_transpose_padding +# 3. conv_general_dilated +def conv_transpose(lhs, rhs, strides, padding, + rhs_dilation=None, dimension_numbers=None, + transpose_kernel=False, precision=None): + """Convenience wrapper for calculating the N-d convolution "transpose". + This function directly calculates a fractionally strided conv rather than + indirectly calculating the gradient (transpose) of a forward convolution. + + Args: + lhs: a rank `n+2` dimensional input array. + rhs: a rank `n+2` dimensional array of kernel weights. + strides: sequence of `n` integers, sets fractional stride. + padding: 'SAME', 'VALID' will set as transpose of corresponding forward + conv, or a sequence of `n` integer 2-tuples describing before-and-after + padding for each `n` spatial dimension. + rhs_dilation: `None`, or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `rhs`. RHS dilation + is also known as atrous convolution. + dimension_numbers: tuple of dimension descriptors as in + lax.conv_general_dilated. Defaults to tensorflow convention. + transpose_kernel: if True flips spatial axes and swaps the input/output + channel axes of the kernel. This makes the output of this function identical + to the gradient-derived functions like keras.layers.Conv2DTranspose + applied to the same kernel. For typical use in neural nets this is completely + pointless and just makes input/output channel specification confusing. + precision: Optional. Either `None`, which means the default precision for + the backend, or a `Precision` enum value. + + Returns: + Transposed N-d convolution, with output padding following the conventions of + keras.layers.Conv2DTranspose. + """ + assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) > 2 + ndims = len(lhs.shape) + one = (1,) * (ndims - 2) + # Set dimensional layout defaults if not specified. 
+ if dimension_numbers is None: + if ndims == 3: + dimension_numbers = ('NHC', 'HIO', 'NHC') + elif ndims == 4: + dimension_numbers = ('NHWC', 'HWIO', 'NHWC') + elif ndims == 5: + dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC') + else: + raise ValueError('No 4+ dimensional dimension_number defaults.') + dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) + k_shape = onp.take(rhs.shape, dn.rhs_spec) + k_sdims = k_shape[2:] + # Calculate correct output shape given padding and strides. + pads: Union[str, Sequence[Tuple[int, int]]] + if padding in {'SAME', 'VALID'}: + if rhs_dilation is None: + rhs_dilation = (1,) * (rhs.ndim - 2) + effective_k_size = map(lambda k, r: (k-1) * r + 1, k_sdims, rhs_dilation) + pads = [_conv_transpose_padding(k, s, padding) + for k,s in zip(effective_k_size, strides)] + else: + pads = padding + if transpose_kernel: + # flip spatial dims and swap input / output channel axes + rhs = _flip_axes(rhs, onp.array(dn.rhs_spec)[2:]) + rhs = onp.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + # return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn, + # precision=precision) + return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn) + + +def tf_dot_general(lhs, rhs, dimension_numbers, precision=None): + """ The general dot operation for TensorFlow. + + An equivalent general dot operation as that in JAX - + + Although there is an implementation in TF XLA, avoid directly using XLA when + possible. + + e.g., non-batched: ij,jk->ik + batched: ijk,ikl->ijl + + Args: + lhs: an array (the left-hand side matrix/vector to be multiplied) + rhs: an array (the right-hand side matrix/vector to be multiplied) + dimension_numbers: (Tuple[Tuple[Sequence[int], Sequence[int]], + Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of the form + ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) + + Returns: + An array that contains the result. 
+ """ + char_list = list(string.ascii_lowercase) + char_list = char_list[8:] + char_list[:8] + lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape) + lhs_rep = char_list[:lhs_rank] + rhs_rep = char_list[lhs_rank:lhs_rank+rhs_rank] + contraction, batch = dimension_numbers + lhs_contraction, rhs_contraction = contraction + if len(lhs_contraction) != len(rhs_contraction): + raise ValueError("The input matrices are required to have the same number " + "of contraction dimensions, but got: lhs {}, rhs: {}".format( + len(lhs_contraction), len(rhs_contraction))) + lhs_batch, rhs_batch = batch + if len(lhs_batch) != len(rhs_batch): + raise ValueError("The input matrices are required to have the same number " + "of batch dimensions, but got: lhs {}, rhs: {}".format( + len(lhs_batch), len(rhs_batch))) + + if len(lhs_batch) == 0 and len(rhs_batch) == 0: + return _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction) + + if (lhs_rank == rhs_rank == 3 and lhs_batch == (0,) and rhs_batch == (0,) + and lhs_contraction == (2,) and rhs_contraction == (1,)): + return np.asarray(tf.linalg.matmul(lhs, rhs)) + + for i in range(len(lhs_contraction)): + rhs_rep[rhs_contraction[i]] = lhs_rep[lhs_contraction[i]] + for i in range(len(lhs_batch)): + rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]] + + output_rep = _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, + rhs_contraction, lhs_batch, rhs_batch) + equation = ''.join(lhs_rep) + ',' + ''.join(rhs_rep) + "->" + output_rep + return np.asarray(tf.einsum(equation, lhs, rhs)) + + +def reduce_window(inputs, init_value, reducer, window_dimensions, strides, + padding): + if reducer not in [np.max, np.add]: + raise TypeError("Only max pooling and average/sum pooling are supported.") + + # Note that there is no need to send in the parameter data format since the + # input is already of default data format - "N...C". The adjustments of the + # input shape is already finished in apply_fun of Pooling in stax. + pooling = "AVG" if pooling_type == "SUM" else pooling_type + output = pool(inputs, window_dimensions, pooling, strides, padding) + if pooling_type in ["MAX", "AVG"]: + return output + # If it is sum pooling, mutiply the output by the number of grids inside a + # window. + # grids = onp.prod(list(window_dimensions)) + return np.asarray(output) + + +# TODO (Zhibo Zhang): Support feature_group_count, batch_group_count and precision, and +# allow lhs_dilation and rhs_dilation to happen at the same time. 
+def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, + rhs_dilation=None, dimension_numbers=None, + feature_group_count=1, batch_group_count=1, precision=None): + """ A general conv API that integrates normal conv, deconvolution, + dilated convolution, etc.""" + # raise TypeError("lhs shape: {}, rhs shape: {}".format(lhs.shape, rhs.shape)) + dim = None + lhs_spec, rhs_spec, out_spec = dimension_numbers + if lhs_spec != out_spec: + raise TypeError("Current implementation requires the `data_format` of the" + "inputs and outputs to be the same.") + if len(lhs_spec) >= 6: + raise TypeError("Current implmentation does not support 4 or higher" + "dimensional convolution, but got: ", len(lhs_spec) - 2) + dim = len(lhs_spec) - 2 + if lhs_dilation and rhs_dilation: + if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim: + lhs_dilation, rhs_dilation = None, None + else: + raise TypeError("Current implementation does not support that deconvolution" + "and dilation to be performed at the same time, but got" + " lhs_dilation: {}, rhs_dilation: {}".format(lhs_dilation, + rhs_dilation)) + print("the dim is: {}".format(dim)) + if padding not in ["SAME", "VALID"]: + raise TypeError("Current implementation requires the padding parameter" + "to be either 'VALID' or 'SAME', but got: ", padding) + # Convert params from int/Sequence[int] to list of ints. + strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter( + window_strides, lhs_dilation, rhs_dilation + ) + # Preprocess the shapes + dim_maps = {} + if isinstance(lhs_spec, str): + dim_maps['I'] = list(rhs_spec).index('I') + dim_maps['O'] = list(rhs_spec).index('O') + dim_maps['N'] = list(lhs_spec).index('N') + dim_maps['C'] = list(lhs_spec).index('C') + else: + dim_maps['I'] = rhs_spec[1] + dim_maps['O'] = rhs_spec[0] + dim_maps['N'] = lhs_spec[0] + dim_maps['C'] = lhs_spec[1] + # data_format, lhs = conv_dim_translator(lhs, lhs_spec, dim) + lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1)) + # Adjust the filters, put the dimension 'I' and 'O' at last. 
+ rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim)) + spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"} + data_format = 'N' + spatial_dim_maps[dim] + 'C' + print("data format: {}".format(data_format)) + tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose], + 2: [nn.conv2d, nn.conv2d_transpose], + 3: [nn.conv3d, nn.conv3d_transpose]} + + output = None + if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): + output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, rhs_dilation) + else: + output_shape = _eval_output_shape(lhs.shape, rhs.shape, padding, window_strides) + output = tf_nn_APIs[dim][1](lhs, rhs, output_shape, strides, padding, data_format, lhs_dilation) + output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C'])) + return np.asarray(output) + + +#-------------------------------private methods------------------------------# + +def _ceil_divide(x1, x2): + return -onp.floor_divide(onp.negative(x1), x2) + + +def _conv_transpose_padding(k, s, padding): + + if padding == 'SAME': + pad_len = k + s - 2 + if s > k - 1: + pad_a = k - 1 + else: + pad_a = int(onp.ceil(pad_len / 2)) + elif padding == 'VALID': + pad_len = k + s - 2 + _max(k - s, 0) + pad_a = k - 1 + else: + raise ValueError('Padding mode must be `SAME` or `VALID`.') + pad_b = pad_len - pad_a + return pad_a, pad_b + + +def _minus(a, b): + return [x for x in a if x not in b] + + +def _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, + lhs_batch, rhs_batch): + """ Compose the output string representation. + + e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik + aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik + + Args: + lhs_rep: A string representation for the left-hand side input array + rhs_rep: A string representation for the right-hand side input array + lhs_contraction: Sequence[int] (the contraction dimensions of lhs) + rhs_contraction: Sequence[int] (the contraction dimensions of rhs) + lhs_batch: Sequence[int] (the batch dimensions of lhs) + rhs_batch: Sequence[int] (the batch dimensions of rhs) + + Returns: + A string representation of the result array. + """ + output_rep = [] + for dim in lhs_batch: + output_rep.append(lhs_rep[dim]) + + for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction): + output_rep.append(lhs_rep[i]) + for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction): + output_rep.append(rhs_rep[i]) + return ''.join(output_rep) + + +def _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction): + """ Compute the non-batched matrix multiplication. + + If it is the general non-batched/single-batched matrix multiplication, + use the highly optimized kernel `tf.tensordot` to handle it. + + Args: + lhs: an array (the left-hand side matrix/vector to be multiplied) + rhs: an array (the right-hand side matrix/vector to be multiplied) + lhs_contraction: Sequence[int] (the contraction dimensions of lhs) + rhs_contraction: Sequence[int] (the contraction dimensions of rhs) + + Returns: + An array that contains the result. + """ + return np.asarray( + tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction)))) + + +def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation): + """ Convert the inputs strides, lhs_dilation, rhs_dilation to the standard + TF conv inputs. 
+ + For example, + in the 3D case, if lhs_dilation = 2, then convert it to [2, 2, 2] + if lhs_dilation = (2, 2, 2), convert it also to [2, 2, 2] + """ + strides = [window_strides] * dim if isinstance(window_strides, int) else \ + list(window_strides) + if lhs_dilation: + lhs_dilation = [lhs_dilation] * dim if isinstance(lhs_dilation, int) else \ + list(lhs_dilation) + if rhs_dilation: + rhs_dilation = [rhs_dilation] * dim if isinstance(rhs_dilation, int) else \ + list(rhs_dilation) + return (strides, lhs_dilation, rhs_dilation) + + +def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides): + """ Evaluate the output shape in for transpose convolutions. + """ + output_shape = [lhs_shape[0]] + for i in range(1, len(lhs_shape) - 1): + output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] + rhs_shape[i]) + output_shape.append(lhs_shape[-1]) + return tf.constant(output_shape) diff --git a/tf_helpers/tf_dot_general.py b/tf_helpers/tf_dot_general.py deleted file mode 100644 index 10f61527..00000000 --- a/tf_helpers/tf_dot_general.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -""" -Construct an equivalent general dot operation as that in JAX - - - -Although there is an implementation in TF XLA, avoid directly using XLA when -possible. - -Zhibo Zhang, 2020.06.30 -""" - -import tensorflow as tf -from tensorflow.python.ops import numpy_ops as tf_np -import string - - -def _minus(a, b): - return [x for x in a if x not in b] - - -def compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, - lhs_batch, rhs_batch): - """ Compose the output string representation. - - e.g., ij, jk, (((1,), (0,)), ((), ())) -> ik - aij, ajk, (((2,), (1,)), ((0,), (0,))) -> aik - - Args: - lhs_rep: A string representation for the left-hand side input array - rhs_rep: A string representation for the right-hand side input array - lhs_contraction: Sequence[int] (the contraction dimensions of lhs) - rhs_contraction: Sequence[int] (the contraction dimensions of rhs) - lhs_batch: Sequence[int] (the batch dimensions of lhs) - rhs_batch: Sequence[int] (the batch dimensions of rhs) - - Returns: - A string representation of the result array. - """ - output_rep = [] - for dim in lhs_batch: - output_rep.append(lhs_rep[dim]) - - for i in _minus(range(len(lhs_rep)), lhs_batch + lhs_contraction): - output_rep.append(lhs_rep[i]) - for i in _minus(range(len(rhs_rep)), rhs_batch + rhs_contraction): - output_rep.append(rhs_rep[i]) - return ''.join(output_rep) - - -def non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction): - """ Compute the non-batched matrix multiplication. - - If it is the general non-batched/single-batched matrix multiplication, - use the highly optimized kernel `tf.tensordot` to handle it. 
- - Args: - lhs: an array (the left-hand side matrix/vector to be multiplied) - rhs: an array (the right-hand side matrix/vector to be multiplied) - lhs_contraction: Sequence[int] (the contraction dimensions of lhs) - rhs_contraction: Sequence[int] (the contraction dimensions of rhs) - - Returns: - An array that contains the result. - """ - return tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction))) - - -def tf_dot_general(lhs, rhs, dimension_numbers): - """ The general dot operation for TensorFlow. - - An equivalent general dot operation as that in JAX - - - Although there is an implementation in TF XLA, avoid directly using XLA when - possible. - - e.g., non-batched: ij,jk->ik - batched: ijk,ikl->ijl - - Args: - lhs: an array (the left-hand side matrix/vector to be multiplied) - rhs: an array (the right-hand side matrix/vector to be multiplied) - dimension_numbers: (Tuple[Tuple[Sequence[int], Sequence[int]], - Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of the form - ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) - - Returns: - An array that contains the result. - """ - char_list = list(string.ascii_lowercase) - char_list = char_list[8:] + char_list[:8] - lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape) - lhs_rep = char_list[:lhs_rank] - rhs_rep = char_list[lhs_rank:lhs_rank+rhs_rank] - contraction, batch = dimension_numbers - lhs_contraction, rhs_contraction = contraction - if len(lhs_contraction) != len(rhs_contraction): - raise ValueError("The input matrices are required to have the same number " - "of contraction dimensions, but got: lhs {}, rhs: {}".format( - len(lhs_contraction), len(rhs_contraction))) - lhs_batch, rhs_batch = batch - if len(lhs_batch) != len(rhs_batch): - raise ValueError("The input matrices are required to have the same number " - "of batch dimensions, but got: lhs {}, rhs: {}".format( - len(lhs_batch), len(rhs_batch))) - - if len(lhs_batch) == 0 and len(rhs_batch) == 0: - return non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction) - - if (lhs_rank == rhs_rank == 3 and lhs_batch == (0,) and rhs_batch == (0,) - and lhs_contraction == (2,) and rhs_contraction == (1,)): - return tf.linalg.matmul(lhs, rhs) - - for i in range(len(lhs_contraction)): - rhs_rep[rhs_contraction[i]] = lhs_rep[lhs_contraction[i]] - for i in range(len(lhs_batch)): - rhs_rep[rhs_batch[i]] = lhs_rep[lhs_batch[i]] - - output_rep = compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, - rhs_contraction, lhs_batch, rhs_batch) - equation = ''.join(lhs_rep) + ',' + ''.join(rhs_rep) + "->" + output_rep - return tf.einsum(equation, lhs, rhs) From 7f5135073aeced7a14aa3f2c2db5bdd815046895 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 15:54:35 -0400 Subject: [PATCH 03/28] Add the updated `tf_lax` file and the updated `lax_tests` file --- tf_helpers/lax.py | 31 ++++++++---- tf_helpers/lax_tests.py | 109 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 11 deletions(-) create mode 100644 tf_helpers/lax_tests.py diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index efe2de39..cf153c3b 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -23,8 +23,11 @@ # from tensorflow.compiler.xla.python import xla_client import builtins from typing import (NamedTuple, Sequence) +import string import numpy as onp +from tensorflow.python.ops import numpy_ops as np import tensorflow as tf +from tensorflow import nn import sys from tf_conv_general import 
conv_general_dilated from tf_reduce_window import reduce_window @@ -260,7 +263,7 @@ def conv_transpose(lhs, rhs, strides, padding, return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn) -def tf_dot_general(lhs, rhs, dimension_numbers, precision=None): +def dot_general(lhs, rhs, dimension_numbers, precision=None): """ The general dot operation for TensorFlow. An equivalent general dot operation as that in JAX - @@ -317,26 +320,28 @@ def tf_dot_general(lhs, rhs, dimension_numbers, precision=None): def reduce_window(inputs, init_value, reducer, window_dimensions, strides, - padding): + padding, base_dilation=None, window_dilation=None): if reducer not in [np.max, np.add]: raise TypeError("Only max pooling and average/sum pooling are supported.") # Note that there is no need to send in the parameter data format since the # input is already of default data format - "N...C". The adjustments of the # input shape is already finished in apply_fun of Pooling in stax. - pooling = "AVG" if pooling_type == "SUM" else pooling_type - output = pool(inputs, window_dimensions, pooling, strides, padding) - if pooling_type in ["MAX", "AVG"]: - return output + pooling = "AVG" if reducer == np.add else "MAX" + output = nn.pool(inputs, window_dimensions, pooling, strides, padding) + # if pooling_type in ["MAX", "AVG"]: + # return output # If it is sum pooling, mutiply the output by the number of grids inside a # window. # grids = onp.prod(list(window_dimensions)) return np.asarray(output) +# TOTO (Zhibo Zhang): Expand the test cases of general convolution and revise +# the according bugs. # TODO (Zhibo Zhang): Support feature_group_count, batch_group_count and precision, and # allow lhs_dilation and rhs_dilation to happen at the same time. -def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, +def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, feature_group_count=1, batch_group_count=1, precision=None): """ A general conv API that integrates normal conv, deconvolution, @@ -345,7 +350,7 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, dim = None lhs_spec, rhs_spec, out_spec = dimension_numbers if lhs_spec != out_spec: - raise TypeError("Current implementation requires the `data_format` of the" + raise TypeError("Current implementation requires the `data_format` of the " "inputs and outputs to be the same.") if len(lhs_spec) >= 6: raise TypeError("Current implmentation does not support 4 or higher" @@ -394,8 +399,8 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, rhs_dilation) else: - output_shape = _eval_output_shape(lhs.shape, rhs.shape, padding, window_strides) - output = tf_nn_APIs[dim][1](lhs, rhs, output_shape, strides, padding, data_format, lhs_dilation) + # output_shape = _eval_output_shape(lhs.shape, rhs.shape, padding, window_strides) + output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, padding, data_format, lhs_dilation) output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C'])) return np.asarray(output) @@ -499,6 +504,10 @@ def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides): """ output_shape = [lhs_shape[0]] for i in range(1, len(lhs_shape) - 1): - output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] + 
rhs_shape[i]) + if padding == "SAME": + output_shape.append((lhs_shape[i] - 1) * window_strides[i-1] + rhs_shape[i]) + if padding == "VALID": + output_shape.append((lhs_shape[i] - 1) * window_strides[i-1]) output_shape.append(lhs_shape[-1]) + print("output shape: {}".format(output_shape)) return tf.constant(output_shape) diff --git a/tf_helpers/lax_tests.py b/tf_helpers/lax_tests.py new file mode 100644 index 00000000..d789030f --- /dev/null +++ b/tf_helpers/lax_tests.py @@ -0,0 +1,109 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +""" +Tests for the general conv operation for TensorFlow. + +Zhibo Zhang, 2020.06.07 +""" +from tensorflow import nn +import tensorflow as tf +import lax +from tensorflow.python.platform import test +from absl.testing import parameterized +import itertools +import numpy as onp +from tensorflow.python.ops import numpy_ops as tfnp +from jax import numpy as jnp +import jax +import sys + +class TFConvGeneralTest(tf.test.TestCase, parameterized.TestCase): + + + @parameterized.parameters( + {"lhs_np": onp.ones((5, 3)), "rhs_np": onp.ones((3, 2)), + "dims": (((1,), (0,)), ((), ()))}, + {"lhs_np": onp.ones((5, 3)), "rhs_np": onp.ones((5, 3)), + "dims": (((0, 1), (0, 1)), ((), ()))}, + {"lhs_np": onp.ones((5, 3, 2)), "rhs_np": onp.ones((2, 3, 2)), + "dims": (((1, 2), (1, 0)), ((), ()))}, + {"lhs_np": onp.ones((6, 5, 3)), "rhs_np": onp.ones((6, 3, 2)), + "dims": (((2,), (1,)), ((0,), (0,)))}, + {"lhs_np": onp.ones((6, 3, 5)), "rhs_np": onp.ones((6, 3, 2)), + "dims": (((1,), (1,)), ((0,), (0,)))}, + {"lhs_np": onp.ones((5, 3, 2, 2)), "rhs_np": onp.ones((5, 2, 2, 6)), + "dims": (((2, 3), (1, 2)), ((0,), (0,)))}, + {"lhs_np": onp.ones((2, 2, 5, 3)), "rhs_np": onp.ones((2, 2, 3, 2)), + "dims": (((3,), (2,)), ((0, 1), (0, 1)))}, + {"lhs_np": onp.ones((2, 2, 5, 2)), "rhs_np": onp.ones((2, 2, 3, 2)), + "dims": (((3,), (1,)), ((0,), (0,)))}, + {"lhs_np": onp.ones((2, 2, 5, 3, 3)), "rhs_np": onp.ones((2, 3, 2, 3, 2)), + "dims": (((4,), (1,)), ((0,), (0,)))}, + ) + def test_tf_dot_general(self, lhs_np, rhs_np, dims): + ans = jax.lax.dot_general(lhs_np, rhs_np, dims) + result = lax.dot_general(lhs_np, rhs_np, dims) + self.assertAllClose(result, tfnp.array(ans)) + + + @parameterized.named_parameters([ + ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}" + "_lhs_dilation={}_rhs_dilation={}" + "_feature_group_count={}_batch_group_count={}_dims={}" + "_perms={}".format(lhs_shape, rhs_shape, + strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, batch_group_count, ",".join(dimension_numbers), perms), + lhs_shape, rhs_shape, strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, batch_group_count, dimension_numbers, perms) + for batch_group_count, feature_group_count in [(1, 1)] + for lhs_shape, rhs_shape in [ + ((b * batch_group_count, i * feature_group_count, 9, w), + (j * feature_group_count * 
batch_group_count, i, 4, 5)) + for w in [0, 10] + for b, i, j in itertools.product([2, 3], repeat=3)] + for strides in [(1, 1), (2, 1)] + for padding in ['SAME'] + for lhs_dilation, rhs_dilation in [ + (None, (1, 1)) + ] + for dimension_numbers, perms in [ + (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])) + ]]) + def testConvGeneralDilated(self, lhs_shape, rhs_shape, strides, + padding, lhs_dilation, rhs_dilation, + feature_group_count, batch_group_count, + dimension_numbers, perms): + tf.print("dimension_numbers: {}".format(dimension_numbers), output_stream=sys.stdout) + lhs_perm, rhs_perm = perms # permute to compatible shapes + + lhs_tf = tfnp.transpose(tfnp.ones(lhs_shape), lhs_perm) + rhs_tf = tfnp.transpose(tfnp.ones(rhs_shape), rhs_perm) + + lhs_jax = jnp.transpose(jnp.ones(lhs_shape), lhs_perm) + rhs_jax = jnp.transpose(jnp.ones(rhs_shape), rhs_perm) + + jax_conv = jax.lax.conv_general_dilated(lhs_jax, rhs_jax, strides, padding, lhs_dilation, + rhs_dilation, dimension_numbers, feature_group_count, batch_group_count) + + tf_conv = lax.conv_general_dilated(lhs_tf, rhs_tf, strides, padding, jax_conv.shape, lhs_dilation, + rhs_dilation, dimension_numbers, feature_group_count, batch_group_count) + + self.assertAllEqual(tf_conv, tfnp.asarray(jax_conv)) + + +if __name__ == "__main__": + test.main() From 36ae0a355444ef9da67297bfabe7270f2423efa3 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 16:12:11 -0400 Subject: [PATCH 04/28] Remove the docstring at the begining of the file --- tf_helpers/lax.py | 6 ------ tf_helpers/lax_tests.py | 5 ----- 2 files changed, 11 deletions(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index cf153c3b..8045db8a 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -13,12 +13,6 @@ # limitations under the License. # ============================================================================== -""" -This file contains TF equivalences for: - 1. `jax.lax.conv_general_shape_tuple` - 2. `jax.lax.conv_transpose_shape_tuple` - 3. `jax.lax.reduce_window_shape_tuple` -""" # from tensorflow.compiler.xla.python import xla_client import builtins diff --git a/tf_helpers/lax_tests.py b/tf_helpers/lax_tests.py index d789030f..054ee782 100644 --- a/tf_helpers/lax_tests.py +++ b/tf_helpers/lax_tests.py @@ -14,11 +14,6 @@ # ============================================================================== -""" -Tests for the general conv operation for TensorFlow. - -Zhibo Zhang, 2020.06.07 -""" from tensorflow import nn import tensorflow as tf import lax From 70924d09c6115b424ad8bfd80c87ba4e29ebfe1e Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 16:17:11 -0400 Subject: [PATCH 05/28] Remove unnecessary imports. --- tf_helpers/lax.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index 8045db8a..166f6725 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -23,8 +23,6 @@ import tensorflow as tf from tensorflow import nn import sys -from tf_conv_general import conv_general_dilated -from tf_reduce_window import reduce_window _max = builtins.max From d20c45ff336a1fd304ff919a967cde9e47d13132 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 16:28:10 -0400 Subject: [PATCH 06/28] Update Travis CI such that its installation contains the TensorFlow-related ecosystem and the tests folder contain both the `lax` and its test cases. 
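(Editorial illustration, not part of this patch series: the NHWC/HWIO case exercised by testConvGeneralDilated above reduces, for 2-D inputs with no dilation, to a plain tf.nn.conv2d call. The shapes below are made-up example values.)

import tensorflow as tf

# NHWC input: batch=2, height=9, width=10, channels=3.
lhs = tf.random.normal((2, 9, 10, 3))
# HWIO kernel: a 4x5 window mapping 3 input channels to 6 output channels.
rhs = tf.random.normal((4, 5, 3, 6))

# With dimension_numbers ('NHWC', 'HWIO', 'NHWC') and stride 1, SAME padding
# keeps the spatial dims and only the channel count changes.
out = tf.nn.conv2d(lhs, rhs, strides=(1, 1), padding='SAME', data_format='NHWC')
print(out.shape)  # (2, 9, 10, 6)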
--- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9932154e..40605067 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,10 +5,10 @@ python: - "3.8" install: - pip install --upgrade pip - - pip install numpy jaxlib tensorflow tensorflow-datasets --upgrade + - pip install pygame==2.0.0.dev6 tfp-nightly tfds-nightly numpy jaxlib jax tf-nightly matplotlib more-itertools --upgrade - pip install git+https://github.com/google/jax.git - pip install -e . + - cp tf_helpers/lax.py tests/ script: - set -e - for f in tests/*.py; do python $f; done - From 2fee307a67c7c35bba991bb0ec17a35d8a359f17 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 18:01:48 -0400 Subject: [PATCH 07/28] Rename the lax test files and move it under the test folder --- tf_helpers/lax_tests.py => tests/lax_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tf_helpers/lax_tests.py => tests/lax_test.py (100%) diff --git a/tf_helpers/lax_tests.py b/tests/lax_test.py similarity index 100% rename from tf_helpers/lax_tests.py rename to tests/lax_test.py From 34787ba670e75528f4fedf793757e938e2ca45e0 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 18:04:34 -0400 Subject: [PATCH 08/28] Adjust the blank lines above the class definition and between the methods --- tests/lax_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 054ee782..f64bff27 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -26,8 +26,8 @@ import jax import sys -class TFConvGeneralTest(tf.test.TestCase, parameterized.TestCase): +class TFConvGeneralTest(tf.test.TestCase, parameterized.TestCase): @parameterized.parameters( {"lhs_np": onp.ones((5, 3)), "rhs_np": onp.ones((3, 2)), @@ -54,7 +54,6 @@ def test_tf_dot_general(self, lhs_np, rhs_np, dims): result = lax.dot_general(lhs_np, rhs_np, dims) self.assertAllClose(result, tfnp.array(ans)) - @parameterized.named_parameters([ ("_lhs_shape={}_rhs_shape={}_strides={}_padding={}" "_lhs_dilation={}_rhs_dilation={}" From 6aa01546356b7a90ed9dd4b74fb9d0f5bc84d9d2 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 18:08:34 -0400 Subject: [PATCH 09/28] Remove unused comments as pointed out in https://github.com/google/neural-tangents/pull/59#discussion_r476739437 --- tf_helpers/lax.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index 166f6725..e6838256 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -14,9 +14,8 @@ # ============================================================================== -# from tensorflow.compiler.xla.python import xla_client import builtins -from typing import (NamedTuple, Sequence) +from typing import NamedTuple, Sequence import string import numpy as onp from tensorflow.python.ops import numpy_ops as np @@ -27,8 +26,6 @@ _max = builtins.max -#---------------------------------main APIs------------------------------------# - def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1): """Compute the shape tuple of a conv given input shapes in canonical order.""" if isinstance(pads, str): @@ -143,8 +140,6 @@ def padtype_to_pads(in_shape, window_shape, window_strides, padding): return [(0, 0)] * len(in_shape) -# helper function: 1. conv_general_permutations -# 2. 
conv_shape_tuple def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, dimension_numbers): lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers) @@ -154,8 +149,6 @@ def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, return tuple(onp.take(out_trans, onp.argsort(out_perm))) -# helper function: 1. conv_general_permutations -# 2. _conv_transpose_padding def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, dimension_numbers): lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers) @@ -174,7 +167,6 @@ def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, return tuple(onp.take(out_trans, onp.argsort(out_perm))) -# helper function: 1. padtype_to_pads def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, padding): window_dimensions = (1,) + window_dimensions + (1,) @@ -186,9 +178,6 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, return tuple(t) -# helper function: 1. conv_dimension_numbers -# 2. _conv_transpose_padding -# 3. conv_general_dilated def conv_transpose(lhs, rhs, strides, padding, rhs_dilation=None, dimension_numbers=None, transpose_kernel=False, precision=None): @@ -250,8 +239,6 @@ def conv_transpose(lhs, rhs, strides, padding, # flip spatial dims and swap input / output channel axes rhs = _flip_axes(rhs, onp.array(dn.rhs_spec)[2:]) rhs = onp.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) - # return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn, - # precision=precision) return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn) @@ -321,11 +308,6 @@ def reduce_window(inputs, init_value, reducer, window_dimensions, strides, # input shape is already finished in apply_fun of Pooling in stax. pooling = "AVG" if reducer == np.add else "MAX" output = nn.pool(inputs, window_dimensions, pooling, strides, padding) - # if pooling_type in ["MAX", "AVG"]: - # return output - # If it is sum pooling, mutiply the output by the number of grids inside a - # window. - # grids = onp.prod(list(window_dimensions)) return np.asarray(output) @@ -376,7 +358,7 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_di dim_maps['O'] = rhs_spec[0] dim_maps['N'] = lhs_spec[0] dim_maps['C'] = lhs_spec[1] - # data_format, lhs = conv_dim_translator(lhs, lhs_spec, dim) + lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1)) # Adjust the filters, put the dimension 'I' and 'O' at last. 
rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim)) @@ -391,14 +373,11 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_di if rhs_dilation or (lhs_dilation is None and rhs_dilation is None): output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, rhs_dilation) else: - # output_shape = _eval_output_shape(lhs.shape, rhs.shape, padding, window_strides) output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, padding, data_format, lhs_dilation) output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C'])) return np.asarray(output) -#-------------------------------private methods------------------------------# - def _ceil_divide(x1, x2): return -onp.floor_divide(onp.negative(x1), x2) From 42e5b545edc6cae91af1a80a95c2dc896b477042 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 18:18:24 -0400 Subject: [PATCH 10/28] Remove the installation of JAX in Travis CI --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 40605067..25470b99 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,10 +5,10 @@ python: - "3.8" install: - pip install --upgrade pip - - pip install pygame==2.0.0.dev6 tfp-nightly tfds-nightly numpy jaxlib jax tf-nightly matplotlib more-itertools --upgrade + - pip install pygame==2.0.0.dev6 tfp-nightly tfds-nightly numpy jaxlib tf-nightly matplotlib more-itertools --upgrade - pip install git+https://github.com/google/jax.git - pip install -e . - - cp tf_helpers/lax.py tests/ + - cp tf_helpers/* tests/ script: - set -e - for f in tests/*.py; do python $f; done From 7dc72d6993fd415dcdc977214a206a515364542c Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 18:32:09 -0400 Subject: [PATCH 11/28] Revert back to the original Travis CI file as mentioned in https://github.com/google/neural-tangents/pull/59#discussion_r476793363 --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 25470b99..9b7d912f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,4 @@ + language: python python: - "3.6" @@ -5,10 +6,9 @@ python: - "3.8" install: - pip install --upgrade pip - - pip install pygame==2.0.0.dev6 tfp-nightly tfds-nightly numpy jaxlib tf-nightly matplotlib more-itertools --upgrade + - pip install numpy jaxlib tensorflow tensorflow-datasets --upgrade - pip install git+https://github.com/google/jax.git - pip install -e . - - cp tf_helpers/* tests/ script: - set -e - for f in tests/*.py; do python $f; done From e0fabe395f568a50292cbd550fae49c31e65121d Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 18:35:56 -0400 Subject: [PATCH 12/28] Remove the `_non_batched_matmul` as suggested in https://github.com/google/neural-tangents/pull/59#discussion_r476798497. 
--- tf_helpers/lax.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index e6838256..33b9b39f 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -281,7 +281,8 @@ def dot_general(lhs, rhs, dimension_numbers, precision=None): len(lhs_batch), len(rhs_batch))) if len(lhs_batch) == 0 and len(rhs_batch) == 0: - return _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction) + return np.asarray( + tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction)))) if (lhs_rank == rhs_rank == 3 and lhs_batch == (0,) and rhs_batch == (0,) and lhs_contraction == (2,) and rhs_contraction == (1,)): @@ -432,25 +433,6 @@ def _compose_output_rep(lhs_rep, rhs_rep, lhs_contraction, rhs_contraction, return ''.join(output_rep) -def _non_batched_matmul(lhs, rhs, lhs_contraction, rhs_contraction): - """ Compute the non-batched matrix multiplication. - - If it is the general non-batched/single-batched matrix multiplication, - use the highly optimized kernel `tf.tensordot` to handle it. - - Args: - lhs: an array (the left-hand side matrix/vector to be multiplied) - rhs: an array (the right-hand side matrix/vector to be multiplied) - lhs_contraction: Sequence[int] (the contraction dimensions of lhs) - rhs_contraction: Sequence[int] (the contraction dimensions of rhs) - - Returns: - An array that contains the result. - """ - return np.asarray( - tf.tensordot(lhs, rhs, axes=(list(lhs_contraction), list(rhs_contraction)))) - - def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation): """ Convert the inputs strides, lhs_dilation, rhs_dilation to the standard TF conv inputs. From 4dbc0b33103e6bd694085adb7a298a18836487e8 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 18:38:24 -0400 Subject: [PATCH 13/28] Remove the extra blank line in Travis CI --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 9b7d912f..ec32e3ee 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,3 @@ - language: python python: - "3.6" From 412eedccb1675daf5d83016157195e6948b3e855 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 19:14:25 -0400 Subject: [PATCH 14/28] Give it a try on the direct file import from the `tf_helpers` folder --- tests/lax_test.py | 2 +- tf_helpers/__init__.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 tf_helpers/__init__.py diff --git a/tests/lax_test.py b/tests/lax_test.py index f64bff27..52a95dcc 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -16,7 +16,7 @@ from tensorflow import nn import tensorflow as tf -import lax +import tf_helpers import lax from tensorflow.python.platform import test from absl.testing import parameterized import itertools diff --git a/tf_helpers/__init__.py b/tf_helpers/__init__.py new file mode 100644 index 00000000..e69de29b From 96c0ff1ce3939408baecd6a47b0aab80e640133a Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 19:42:21 -0400 Subject: [PATCH 15/28] Remove the extra blank line and revise the import typo. 
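(Editorial aside, not part of these patches: after the _non_batched_matmul removal above, the non-batched path of dot_general is a single inlined tf.tensordot over the contraction axes. A quick sketch with made-up shapes shows it agrees with an ordinary matrix product.)

import tensorflow as tf

lhs = tf.random.normal((5, 3))
rhs = tf.random.normal((3, 2))

# dimension_numbers (((1,), (0,)), ((), ())): contract lhs axis 1 with rhs
# axis 0 and use no batch axes, so the tensordot branch applies.
out_tensordot = tf.tensordot(lhs, rhs, axes=([1], [0]))
out_matmul = tf.linalg.matmul(lhs, rhs)

tf.debugging.assert_near(out_tensordot, out_matmul)
print(out_tensordot.shape)  # (5, 2)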
--- tests/lax_test.py | 2 +- tf_helpers/lax.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 52a95dcc..8080c8cd 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -16,7 +16,7 @@ from tensorflow import nn import tensorflow as tf -import tf_helpers import lax +from tf_helpers import lax from tensorflow.python.platform import test from absl.testing import parameterized import itertools diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index 33b9b39f..d4a68ce2 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -384,7 +384,6 @@ def _ceil_divide(x1, x2): def _conv_transpose_padding(k, s, padding): - if padding == 'SAME': pad_len = k + s - 2 if s > k - 1: From 23d998d6880e1bb898e2fad74ab18d3176aa84ef Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Tue, 25 Aug 2020 20:36:32 -0400 Subject: [PATCH 16/28] Add the TF version of `ostax` purely for the use in Neural Tangents. --- tf_helpers/tf_jax_stax.py | 338 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 338 insertions(+) create mode 100644 tf_helpers/tf_jax_stax.py diff --git a/tf_helpers/tf_jax_stax.py b/tf_helpers/tf_jax_stax.py new file mode 100644 index 00000000..2a5fe970 --- /dev/null +++ b/tf_helpers/tf_jax_stax.py @@ -0,0 +1,338 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stax is a small but flexible neural net specification library from scratch. + +For an example of its use, see examples/resnet50.py. +""" + +import functools +import itertools +import operator as op + +import sys +import tensorflow as tf +import tensorflow_probability as tfp +from tensorflow.python.ops import numpy_ops as tfnp +from tf_helpers import lax +from tf_shape_conversion import shape_conversion +import numpy as onp +from stateless_random_ops import split +from stateless_random_ops import stateless_random_normal as rn +from tensorflow.random import stateless_uniform + +from tensorflow.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu, + leaky_relu, selu) +from tensorflow import zeros_initializer as zi +from tensorflow import ones_initializer as oi + +# Following the convention used in Keras and tf.layers, we use CamelCase for the +# names of layer constructors, like Conv and Relu, while using snake_case for +# other functions, like tfnp.conv and relu. + +# Each layer constructor function returns an (init_fun, apply_fun) pair, where +# init_fun: takes an rng key and an input shape and returns an +# (output_shape, params) pair, +# apply_fun: takes params, inputs, and an rng key and applies the layer. 
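(Editorial sketch, not part of this patch: a minimal layer following the (init_fun, apply_fun) convention described in the comment above. The Identity layer and shapes here are made up; note that the concrete layers below return a zeros placeholder of the output shape rather than the bare shape tuple.)

def Identity():
  """A parameter-free layer: the output shape equals the input shape."""
  def init_fun(rng, input_shape):
    return input_shape, ()
  def apply_fun(params, inputs, **kwargs):
    return inputs
  return init_fun, apply_fun

# Usage: the rng seed and input shape go in, an (output_shape, params) pair
# comes out, and apply_fun later consumes params plus real inputs.
init_fun, apply_fun = Identity()
output_shape, params = init_fun(rng=(0, 0), input_shape=(8, 32))
# outputs = apply_fun(params, inputs)  # inputs: any array of shape (8, 32)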
+ + +def Dense(out_dim, W_init=rn, b_init=rn): + """Layer constructor function for a dense (fully-connected) layer.""" + def init_fun(rng, input_shape): + output_shape = input_shape[:-1] + (out_dim,) + keys = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2) + k1 = keys[0] + k2 = keys[1] + # convert the two keys from shape (2,) into a scalar + k1 = stateless_uniform(shape=[], seed=k1, minval=None, maxval=None, dtype=tf.int32) + k2 = stateless_uniform(shape=[], seed=k2, minval=None, maxval=None, dtype=tf.int32) + W = W_init(seed=k1, shape=(input_shape[-1], out_dim)) + b = b_init(seed=k2, shape=(out_dim,)) + return tfnp.zeros(output_shape), (W.numpy(), b.numpy()) + def apply_fun(params, inputs, **kwargs): + W, b = params + return tfnp.dot(inputs, W) + b + return init_fun, apply_fun + + +def GeneralConv(dimension_numbers, out_chan, filter_shape, + strides=None, padding='VALID', W_init=rn, + b_init=rn): + """Layer construction function for a general convolution layer.""" + lhs_spec, rhs_spec, out_spec = dimension_numbers + one = (1,) * len(filter_shape) + strides = strides or one + def init_fun(rng, input_shape): + input_shape = shape_conversion(input_shape) + filter_shape_iter = iter(filter_shape) + kernel_shape = [out_chan if c == 'O' else + input_shape[lhs_spec.index('C')] if c == 'I' else + next(filter_shape_iter) for c in rhs_spec] + output_shape = lax.conv_general_shape_tuple( + input_shape, kernel_shape, strides, padding, dimension_numbers) + bias_shape = [out_chan if c == 'C' else 1 for c in out_spec] + bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) + keys = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2) + k1 = keys[0] + k2 = keys[1] + W = W_init(seed=k1, shape=kernel_shape) + b = b_init(stddev=1e-6, seed=k2, shape=bias_shape) + return tfnp.zeros(output_shape), (W, b) + def apply_fun(params, inputs, **kwargs): + W, b = params + return lax.conv_general_dilated(inputs, W, strides, padding, one, one, + dimension_numbers=dimension_numbers) + b + return init_fun, apply_fun +Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC')) + + +def elementwise(fun, **fun_kwargs): + """Layer that applies a scalar function elementwise on its inputs.""" + def init_fun(rng, input_shape): + return (tfnp.zeros(input_shape), ()) + # init_fun = lambda rng, input_shape: (tfnp.zeros(input_shape), ()) + apply_fun = lambda params, inputs, **kwargs: fun(inputs, **fun_kwargs) + return init_fun, apply_fun +Tanh = elementwise(tfnp.tanh) +Relu = elementwise(relu) +Exp = elementwise(tfnp.exp) +LogSoftmax = elementwise(log_softmax, axis=-1) +Softmax = elementwise(softmax, axis=-1) +Softplus = elementwise(softplus) +Sigmoid = elementwise(sigmoid) +Elu = elementwise(elu) +LeakyRelu = elementwise(leaky_relu) +Selu = elementwise(selu) + + +def _pooling_layer(reducer, init_val, rescaler=None): + def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None): + """Layer construction function for a pooling layer.""" + strides = strides or (1,) * len(window_shape) + rescale = rescaler(window_shape, strides, padding) if rescaler else None + + dim = len(window_shape) + batch_dim, channel_dim = None, None + if spec is None: + batch_dim, channel_dim = 0, len(window_shape) + 1 + else: + batch_dim, channel_dim = spec.index('N'), spec.index('C') + window_shape = window_shape + strides = strides + + def init_fun(rng, input_shape): + # Move the batch and channel dimension of the input shape such + # that it is of data format "NHWC" + shape = [input_shape[batch_dim]] + for 
i in range(len(input_shape)): + if i not in [batch_dim, channel_dim]: + shape.append(input_shape[i]) + shape.append(input_shape[channel_dim]) + out_shape = lax.reduce_window_shape_tuple(shape, window_shape, + strides, padding) + return tfnp.zeros(out_shape), () + def apply_fun(params, inputs, **kwargs): + inputs = onp.moveaxis(inputs, (batch_dim, channel_dim), \ + (0, dim + 1)) + output = lax.reduce_window(inputs, init_val, reducer, window_shape, + strides, padding) + return rescale(out, inputs, spec) if rescale else out + # return output + return tfnp.array(output) + return init_fun, apply_fun + return PoolingLayer +MaxPool = _pooling_layer(tfnp.max, -tfnp.inf) + + +def _normalize_by_window_size(dims, strides, padding): + def rescale(outputs, inputs, spec): + if spec is None: + non_spatial_axes = 0, inputs.ndim - 1 + else: + non_spatial_axes = spec.index('N'), spec.index('C') + + spatial_shape = tuple(inputs.shape[i] + for i in range(inputs.ndim) + if i not in non_spatial_axes) + one = tfnp.ones(spatial_shape, dtype=inputs.dtype) + window_sizes = lax.reduce_window(one, 0., tfnp.add, dims, strides, padding) + for i in sorted(non_spatial_axes): + window_sizes = tfnp.expand_dims(window_sizes, i) + + return outputs * window_sizes + return rescale +SumPool = _pooling_layer(tfnp.add, 0., _normalize_by_window_size) +AvgPool = _pooling_layer(tfnp.add, 0.) + + +def Flatten(): + """Layer construction function for flattening all but the leading dim.""" + def init_fun(rng, input_shape): + output_shape = input_shape[0], functools.reduce(op.mul, input_shape[1:], 1) + return tfnp.zeros(output_shape), () + def apply_fun(params, inputs, **kwargs): + return tfnp.reshape(inputs, (inputs.shape[0], -1)) + return init_fun, apply_fun +Flatten = Flatten() + + +def Identity(): + """Layer construction function for an identity layer.""" + init_fun = lambda rng, input_shape: (tfnp.zeros(input_shape), ()) + apply_fun = lambda params, inputs, **kwargs: inputs + return init_fun, apply_fun +Identity = Identity() + + +def FanOut(num): + """Layer construction function for a fan-out layer.""" + def init_fun(rng, input_shape): + return ([tfnp.zeros(input_shape)] * num, ()) + apply_fun = lambda params, inputs, **kwargs: [inputs] * num + return init_fun, apply_fun + + +def FanInSum(): + """Layer construction function for a fan-in sum layer.""" + init_fun = lambda rng, input_shape: (tfnp.zeros(input_shape[0]), ()) + apply_fun = lambda params, inputs, **kwargs: sum(inputs) + return init_fun, apply_fun +FanInSum = FanInSum() + + +def FanInConcat(axis=-1): + """Layer construction function for a fan-in concatenation layer.""" + def init_fun(rng, input_shape): + ax = axis % len(input_shape[0]) + concat_size = sum(shape[ax] for shape in input_shape) + out_shape = input_shape[0][:ax] + (concat_size,) + input_shape[0][ax+1:] + return tfnp.zeros(out_shape), () + def apply_fun(params, inputs, **kwargs): + return tfnp.concatenate(inputs, axis) + return init_fun, apply_fun + + +def Dropout(rate, mode='train'): + """Layer construction function for a dropout layer with given rate.""" + def init_fun(rng, input_shape): + return tfnp.zeros(input_shape), () + def apply_fun(params, inputs, **kwargs): + rng = kwargs.get('rng', None) + if rng is None: + msg = ("Dropout layer requires apply_fun to be called with a PRNG key " + "argument. 
That is, instead of `apply_fun(params, inputs)`, call " + "it like `apply_fun(params, inputs, rng)` where `rng` is a " + "jax.random.PRNGKey value.") + raise ValueError(msg) + if mode == 'train': + prob = tf.ones(inputs.shape) * rate + keep = stateless_uniform(shape=inputs.shape, seed=rng, minval=0, maxval=1) < prob + return tfnp.where(keep, inputs / rate, 0) + else: + return inputs + return init_fun, apply_fun + + +# Composing layers via combinators +def serial(*layers): + """Combinator for composing layers in serial. + + Args: + *layers: a sequence of layers, each an (init_fun, apply_fun) pair. + + Returns: + A new layer, meaning an (init_fun, apply_fun) pair, representing the serial + composition of the given sequence of layers. + """ + nlayers = len(layers) + init_funs, apply_funs = zip(*layers) + def init_fun(rng, input_shape): + params = [] + i = 0 + for init_fun in init_funs: + i += 1 + keys = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2) + rng = keys[0] + layer_rng = keys[1] + input_shape = shape_conversion(input_shape) + input_shape, param = init_fun(layer_rng, input_shape) + params.append(param) + return input_shape, params + def apply_fun(params, inputs, **kwargs): + rng = kwargs.pop('rng', None) + rngs = None + if rng is not None: + rngs = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=nlayers) + else: + rngs = (None,) * nlayers + for i in range(nlayers): + inputs = apply_funs[i](params[i], inputs, rng=rngs[i], **kwargs) + return inputs + return init_fun, apply_fun + + +def parallel(*layers): + """Combinator for composing layers in parallel. + + The layer resulting from this combinator is often used with the FanOut and + FanInSum layers. + + Args: + *layers: a sequence of layers, each an (init_fun, apply_fun) pair. + + Returns: + A new layer, meaning an (init_fun, apply_fun) pair, representing the + parallel composition of the given sequence of layers. In particular, the + returned layer takes a sequence of inputs and returns a sequence of outputs + with the same length as the argument `layers`. + """ + nlayers = len(layers) + init_funs, apply_funs = zip(*layers) + def init_fun(rng, input_shape): + rngs = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=nlayers) + result = [] + for i in range(nlayers): + result.append(init_funs[i](rngs[i], input_shape[i])) + return zip(*result) + def apply_fun(params, inputs, **kwargs): + rng = kwargs.pop('rng', None) + rngs = None + if rng is not None: + rngs = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=nlayers) + else: + rngs = (None,) * nlayers + result = [] + for i in range(len(apply_funs)): + result.append(apply_funs[i](params[i], inputs[i], rng=rngs[i], **kwargs)) + return result + return init_fun, apply_fun + + +def shape_dependent(make_layer): + """Combinator to delay layer constructor pair until input shapes are known. + + Args: + make_layer: a one-argument function that takes an input shape as an argument + (a tuple of positive integers) and returns an (init_fun, apply_fun) pair. + + Returns: + A new layer, meaning an (init_fun, apply_fun) pair, representing the same + layer as returned by `make_layer` but with its construction delayed until + input shapes are known. 
+ """ + def init_fun(rng, input_shape): + return make_layer(input_shape)[0](rng, input_shape) + def apply_fun(params, inputs, **kwargs): + return make_layer(inputs.shape)[1](params, inputs, **kwargs) + return init_fun, apply_fun From c04e7b83667f80790fe177b5e8b7aab80299ef99 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 12:52:34 -0400 Subject: [PATCH 17/28] Rename `tf_jax_stax` to `stax` --- tf_helpers/{tf_jax_stax.py => stax.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tf_helpers/{tf_jax_stax.py => stax.py} (100%) diff --git a/tf_helpers/tf_jax_stax.py b/tf_helpers/stax.py similarity index 100% rename from tf_helpers/tf_jax_stax.py rename to tf_helpers/stax.py From 70257b9fb4c713244d92a78d108bcb96edf073c1 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 12:55:37 -0400 Subject: [PATCH 18/28] Remove the unused print statement --- tf_helpers/lax.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index d4a68ce2..c7df2bc5 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -339,7 +339,6 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_di "and dilation to be performed at the same time, but got" " lhs_dilation: {}, rhs_dilation: {}".format(lhs_dilation, rhs_dilation)) - print("the dim is: {}".format(dim)) if padding not in ["SAME", "VALID"]: raise TypeError("Current implementation requires the padding parameter" "to be either 'VALID' or 'SAME', but got: ", padding) @@ -365,7 +364,6 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_di rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim)) spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"} data_format = 'N' + spatial_dim_maps[dim] + 'C' - print("data format: {}".format(data_format)) tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose], 2: [nn.conv2d, nn.conv2d_transpose], 3: [nn.conv3d, nn.conv3d_transpose]} @@ -461,5 +459,4 @@ def _eval_output_shape(lhs_shape, rhs_shape, padding, window_strides): if padding == "VALID": output_shape.append((lhs_shape[i] - 1) * window_strides[i-1]) output_shape.append(lhs_shape[-1]) - print("output shape: {}".format(output_shape)) return tf.constant(output_shape) From d92e0fc9c8955ccc015dbb9deae2a5d5aca0faf2 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 12:57:05 -0400 Subject: [PATCH 19/28] Remove the unused `tf nn` import in lax tests. 
--- tests/lax_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8080c8cd..44da65da 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -14,7 +14,6 @@ # ============================================================================== -from tensorflow import nn import tensorflow as tf from tf_helpers import lax from tensorflow.python.platform import test From 94322960386205b43bc69a8dd741ce3993bebe8f Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 14:48:45 -0400 Subject: [PATCH 20/28] Rename the lax tests --- tests/lax_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 44da65da..f4501782 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -26,7 +26,7 @@ import sys -class TFConvGeneralTest(tf.test.TestCase, parameterized.TestCase): +class TFLaxTest(tf.test.TestCase, parameterized.TestCase): @parameterized.parameters( {"lhs_np": onp.ones((5, 3)), "rhs_np": onp.ones((3, 2)), From 51f78a4a268684aef7dbc2fdf62e9f7603f33c74 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 14:51:02 -0400 Subject: [PATCH 21/28] Remove the unused comments --- tf_helpers/lax.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index c7df2bc5..2310a5fb 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -135,7 +135,6 @@ def padtype_to_pads(in_shape, window_shape, window_strides, padding): pad_sizes = onp.maximum(0, (out_shape - 1) * window_strides + window_shape - in_shape) return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] - # elif padding == PaddingType.VALID: elif padding == "VALID": return [(0, 0)] * len(in_shape) @@ -321,7 +320,6 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_di feature_group_count=1, batch_group_count=1, precision=None): """ A general conv API that integrates normal conv, deconvolution, dilated convolution, etc.""" - # raise TypeError("lhs shape: {}, rhs shape: {}".format(lhs.shape, rhs.shape)) dim = None lhs_spec, rhs_spec, out_spec = dimension_numbers if lhs_spec != out_spec: From 9357a243ec8d3c685cee47faee2bd6c94dc3a0fa Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 16:16:10 -0400 Subject: [PATCH 22/28] Add an extra `batch` dimension and an extra `channel` dimension to pass the TF pool shape checker in order to make the TF `reduce_window` API consistent with JAX `reduce_window`. --- tf_helpers/lax.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index 2310a5fb..425062bd 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -300,6 +300,9 @@ def dot_general(lhs, rhs, dimension_numbers, precision=None): def reduce_window(inputs, init_value, reducer, window_dimensions, strides, padding, base_dilation=None, window_dilation=None): + # Add an extra "batch" dimension and an extra "channel" dimension to pass the + # TensorFlow pool dimensionality checker. + inputs = np.expand_dims(inputs, axis=(0, inputs.ndim)) if reducer not in [np.max, np.add]: raise TypeError("Only max pooling and average/sum pooling are supported.") From 27a69822ca02a21f5f809f585e1d6f30b1364947 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 16:29:58 -0400 Subject: [PATCH 23/28] Add the window_shape and strides dimension expansion that appear in JAX stax back to TF stax. 
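Concretely, for the default `spec=None` case (where `batch_dim=0` and
`channel_dim=len(window_shape) + 1`), the hunk below expands a 2-D pooling
window as in this small sketch:

    window_shape, strides = (2, 2), (2, 2)
    for i in (0, 3):                      # (batch_dim, channel_dim) for spec=None
        window_shape = window_shape[:i] + (1,) + window_shape[i:]
        strides = strides[:i] + (1,) + strides[i:]
    # window_shape == strides == (1, 2, 2, 1)

so the window and stride tuples line up with the "N...C"-formatted input that
`lax.reduce_window` receives from `apply_fun`.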
--- tf_helpers/stax.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tf_helpers/stax.py b/tf_helpers/stax.py index 2a5fe970..8da049b4 100644 --- a/tf_helpers/stax.py +++ b/tf_helpers/stax.py @@ -128,6 +128,11 @@ def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None): batch_dim, channel_dim = 0, len(window_shape) + 1 else: batch_dim, channel_dim = spec.index('N'), spec.index('C') + + for i in (batch_dim, channel_dim): + window_shape = window_shape[:i] + (1,) + window_shape[i:] + strides = strides[:i] + (1,) + strides[i:] + window_shape = window_shape strides = strides From a020275318c30824a51973c8327acb2a8ebde3c4 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 16:38:40 -0400 Subject: [PATCH 24/28] Revert the changes in Travis CI. --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index ec32e3ee..9932154e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,3 +11,4 @@ install: script: - set -e - for f in tests/*.py; do python $f; done + From 7df8a230d272d338d61f9b5c22eb271d1436ff14 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 16:46:05 -0400 Subject: [PATCH 25/28] Replace the vanilla NumPy support with TF NumPy support. --- tf_helpers/lax.py | 42 +++++++++++++++++++++--------------------- tf_helpers/stax.py | 4 ++-- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index 425062bd..ee78144c 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -34,11 +34,11 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1): msg = "Wrong number of explicit pads for convolution: expected {}, got {}." raise TypeError(msg.format(len(lhs_shape) - 2, len(pads))) - lhs_padded = onp.add(lhs_shape[2:], onp.sum(onp.array(pads).reshape(-1, 2), + lhs_padded = onp.add(lhs_shape[2:], np.sum(np.array(pads).reshape(-1, 2), axis=1)) - out_space = onp.floor_divide( - onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1 - out_space = onp.maximum(0, out_space) + out_space = np.floor_divide( + np.subtract(lhs_padded, rhs_shape[2:]), strides) + 1 + out_space = np.maximum(0, out_space) assert lhs_shape[0] % batch_group_count == 0 out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0]) return tuple(out_shape + tuple(out_space)) @@ -132,7 +132,7 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers): def padtype_to_pads(in_shape, window_shape, window_strides, padding): if padding == "SAME": out_shape = _ceil_divide(in_shape, window_strides) - pad_sizes = onp.maximum(0, (out_shape - 1) * window_strides + + pad_sizes = np.maximum(0, (out_shape - 1) * window_strides + window_shape - in_shape) return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] elif padding == "VALID": @@ -142,28 +142,28 @@ def padtype_to_pads(in_shape, window_shape, window_strides, padding): def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, dimension_numbers): lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers) - lhs_trans = onp.take(lhs_shape, lhs_perm) - rhs_trans = onp.take(rhs_shape, rhs_perm) + lhs_trans = np.take(lhs_shape, lhs_perm) + rhs_trans = np.take(rhs_shape, rhs_perm) out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding) - return tuple(onp.take(out_trans, onp.argsort(out_perm))) + return tuple(np.take(out_trans, np.argsort(out_perm))) def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, 
padding, dimension_numbers): lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers) - lhs_trans = onp.take(lhs_shape, lhs_perm) - rhs_trans = onp.take(rhs_shape, rhs_perm) + lhs_trans = np.take(lhs_shape, lhs_perm) + rhs_trans = np.take(rhs_shape, rhs_perm) if isinstance(padding, str): padding = [_conv_transpose_padding(k, s, padding) for k,s in zip(rhs_trans[2:], window_strides)] - padding = list(map(onp.sum, padding)) + padding = list(map(np.sum, padding)) unpad_out_space = [(i-1) * s - k + 2 for i, k, s in zip(lhs_trans[2:], rhs_trans[2:], window_strides)] - out_space = onp.sum([unpad_out_space, padding], axis=0).tolist() + out_space = np.sum([unpad_out_space, padding], axis=0).tolist() out_trans = tuple((lhs_trans[0], rhs_trans[0]) + tuple(out_space)) - return tuple(onp.take(out_trans, onp.argsort(out_perm))) + return tuple(np.take(out_trans, np.argsort(out_perm))) def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, @@ -171,9 +171,9 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, window_dimensions = (1,) + window_dimensions + (1,) window_strides = (1,) + window_strides + (1,) pads = padtype_to_pads(operand_shape, window_dimensions, window_strides, padding) - operand_padded = onp.add(operand_shape, onp.add(*zip(*pads))) - t = onp.floor_divide( - onp.subtract(operand_padded, window_dimensions), window_strides) + 1 + operand_padded = np.add(operand_shape, np.add(*zip(*pads))) + t = np.floor_divide( + np.subtract(operand_padded, window_dimensions), window_strides) + 1 return tuple(t) @@ -222,7 +222,7 @@ def conv_transpose(lhs, rhs, strides, padding, else: raise ValueError('No 4+ dimensional dimension_number defaults.') dn = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) - k_shape = onp.take(rhs.shape, dn.rhs_spec) + k_shape = np.take(rhs.shape, dn.rhs_spec) k_sdims = k_shape[2:] # Calculate correct output shape given padding and strides. 
pads: Union[str, Sequence[Tuple[int, int]]] @@ -236,8 +236,8 @@ def conv_transpose(lhs, rhs, strides, padding, pads = padding if transpose_kernel: # flip spatial dims and swap input / output channel axes - rhs = _flip_axes(rhs, onp.array(dn.rhs_spec)[2:]) - rhs = onp.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) + rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn) @@ -379,7 +379,7 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_di def _ceil_divide(x1, x2): - return -onp.floor_divide(onp.negative(x1), x2) + return -np.floor_divide(np.negative(x1), x2) def _conv_transpose_padding(k, s, padding): @@ -388,7 +388,7 @@ def _conv_transpose_padding(k, s, padding): if s > k - 1: pad_a = k - 1 else: - pad_a = int(onp.ceil(pad_len / 2)) + pad_a = int(np.ceil(pad_len / 2)) elif padding == 'VALID': pad_len = k + s - 2 + _max(k - s, 0) pad_a = k - 1 diff --git a/tf_helpers/stax.py b/tf_helpers/stax.py index 8da049b4..a05362f9 100644 --- a/tf_helpers/stax.py +++ b/tf_helpers/stax.py @@ -132,7 +132,7 @@ def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None): for i in (batch_dim, channel_dim): window_shape = window_shape[:i] + (1,) + window_shape[i:] strides = strides[:i] + (1,) + strides[i:] - + window_shape = window_shape strides = strides @@ -148,7 +148,7 @@ def init_fun(rng, input_shape): strides, padding) return tfnp.zeros(out_shape), () def apply_fun(params, inputs, **kwargs): - inputs = onp.moveaxis(inputs, (batch_dim, channel_dim), \ + inputs = np.moveaxis(inputs, (batch_dim, channel_dim), \ (0, dim + 1)) output = lax.reduce_window(inputs, init_val, reducer, window_shape, strides, padding) From d7304566dee1edf33a49ccbba40afae4e741e9df Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 16:56:05 -0400 Subject: [PATCH 26/28] Remove the unused lines and unused `moveaxis`. --- tf_helpers/stax.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tf_helpers/stax.py b/tf_helpers/stax.py index a05362f9..a633e09b 100644 --- a/tf_helpers/stax.py +++ b/tf_helpers/stax.py @@ -148,13 +148,9 @@ def init_fun(rng, input_shape): strides, padding) return tfnp.zeros(out_shape), () def apply_fun(params, inputs, **kwargs): - inputs = np.moveaxis(inputs, (batch_dim, channel_dim), \ - (0, dim + 1)) output = lax.reduce_window(inputs, init_val, reducer, window_shape, strides, padding) return rescale(out, inputs, spec) if rescale else out - # return output - return tfnp.array(output) return init_fun, apply_fun return PoolingLayer MaxPool = _pooling_layer(tfnp.max, -tfnp.inf) From 588b7483caf9b84deaa8467e365738c98eb81618 Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 17:04:16 -0400 Subject: [PATCH 27/28] Remove the extra 2 dimensions in the output of TF `reduce_window`. --- tf_helpers/lax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index ee78144c..4a615ec3 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -311,7 +311,7 @@ def reduce_window(inputs, init_value, reducer, window_dimensions, strides, # input shape is already finished in apply_fun of Pooling in stax. 
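  # A clarifying note on the mapping below: tf.nn.pool only offers "AVG" and
  # "MAX" pooling, so the np.add reducer is lowered to "AVG"; stax's SumPool
  # then recovers the sum by rescaling the averaged output with the window size
  # (via _normalize_by_window_size), while AvgPool uses the "AVG" result as-is.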
pooling = "AVG" if reducer == np.add else "MAX" output = nn.pool(inputs, window_dimensions, pooling, strides, padding) - return np.asarray(output) + return np.squeeze(np.asarray(output), axis=(0, output.ndim - 1)) # TOTO (Zhibo Zhang): Expand the test cases of general convolution and revise From ab173740afbde13aeba19e61ad3fc4d84e8dbe0b Mon Sep 17 00:00:00 2001 From: DarrenZhang01 <18633059886@163.com> Date: Wed, 26 Aug 2020 17:06:58 -0400 Subject: [PATCH 28/28] Move `np.asarray` wrapper to TF `pool`. --- tf_helpers/lax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf_helpers/lax.py b/tf_helpers/lax.py index 4a615ec3..483f05a5 100644 --- a/tf_helpers/lax.py +++ b/tf_helpers/lax.py @@ -310,8 +310,8 @@ def reduce_window(inputs, init_value, reducer, window_dimensions, strides, # input is already of default data format - "N...C". The adjustments of the # input shape is already finished in apply_fun of Pooling in stax. pooling = "AVG" if reducer == np.add else "MAX" - output = nn.pool(inputs, window_dimensions, pooling, strides, padding) - return np.squeeze(np.asarray(output), axis=(0, output.ndim - 1)) + output = np.asarray(nn.pool(inputs, window_dimensions, pooling, strides, padding)) + return np.squeeze(output, axis=(0, output.ndim - 1)) # TOTO (Zhibo Zhang): Expand the test cases of general convolution and revise
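With the last three `reduce_window` changes in place (the dimension expansion
from PATCH 22, the squeeze from PATCH 27, and the `np.asarray` move in this
patch), the helper takes and returns plain "N...C" arrays. A minimal sketch of
the intended call, assuming the patched `tf_helpers.lax` and using
`tensorflow.python.ops.numpy_ops` as the array library (the window and stride
tuples already carry the leading/trailing 1s that stax's pooling layers insert):

    import numpy as onp
    from tensorflow.python.ops import numpy_ops as tfnp
    from tf_helpers import lax

    x = tfnp.asarray(onp.arange(16.0).reshape(1, 4, 4, 1))   # NHWC input
    # 2x2 max pooling with stride 2; reduce_window internally adds, pools over,
    # and then squeezes off the extra batch/channel dims that tf.nn.pool needs.
    y = lax.reduce_window(x, -tfnp.inf, tfnp.max, (1, 2, 2, 1),
                          (1, 2, 2, 1), "VALID")
    # expected result shape under these assumptions: (1, 2, 2, 1)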