Low rank locally connected layer, data, and model used for the paper.
PiperOrigin-RevId: 295752649
gamaleldin authored and copybara-github committed Feb 18, 2020
1 parent ac96079 commit fdc4c03
Showing 11 changed files with 1,861 additions and 0 deletions.
14 changes: 14 additions & 0 deletions low_rank_local_connectivity/README.md
@@ -0,0 +1,14 @@
Revisiting Spatial Invariance with Low-Rank Local Connectivity
https://arxiv.org/abs/2002.02959

Work in progress

This is the directory for the low-rank locally connected layer and experiments.

We develop a low-rank locally connected (LRLC) layer that can parametrically
adjust the degree of spatial invariance. This layer is one particular method
of relaxing spatial invariance by reducing weight sharing. Rather than
learning a single filter bank applied at every position, as in a convolutional
layer, or a separate filter bank at each position, as in a locally connected
layer, the LRLC layer learns a set of K filter banks, which are linearly
combined at each spatial position using K per-position combining weights.
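
For intuition, a minimal TensorFlow sketch of the combining step is below. It
is illustrative only (the function name and shapes are ours, not the layer
implementation in this commit). Because convolution is linear in its filters,
combining K filter banks at each position is equivalent to computing K
convolutions and mixing their outputs with the per-position weights:

```python
import tensorflow as tf

def lrlc_sketch(inputs, filter_banks, combining_weights):
  """inputs: (B, H, W, Cin); filter_banks: K kernels, each (kh, kw, Cin, Cout);
  combining_weights: (H, W, K), e.g. softmax-normalized over K."""
  # One convolution per filter bank: K tensors of shape (B, H, W, Cout).
  outs = [tf.nn.conv2d(inputs, f, strides=[1, 1, 1, 1], padding="SAME")
          for f in filter_banks]
  stacked = tf.stack(outs, axis=-1)                       # (B, H, W, Cout, K)
  w = combining_weights[tf.newaxis, :, :, tf.newaxis, :]  # (1, H, W, 1, K)
  # Per-position linear combination of the K filter-bank responses.
  return tf.reduce_sum(stacked * w, axis=-1)              # (B, H, W, Cout)
```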
15 changes: 15 additions & 0 deletions low_rank_local_connectivity/__init__.py
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

289 changes: 289 additions & 0 deletions low_rank_local_connectivity/data_provider.py
@@ -0,0 +1,289 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data provider with an argument to control data augmentation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tensorflow as tf
import tensorflow_datasets as tfds

from low_rank_local_connectivity import utils


def extract_data(data, preprocess_image):
"""Extracts image, label and create a mask."""
image = data["image"]
# Reserve label 0 for background
label = tf.cast(data["label"], dtype=tf.int32)
# Create a mask variable to track the real vs padded data in the last batch.
mask = 1.
image = preprocess_image(image)
return image, label, mask


def construct_iterator(dataset_builder,
                       split,
                       preprocess_fn,
                       batch_size,
                       is_training):
  """Constructs a data iterator.

  Args:
    dataset_builder: tensorflow_datasets dataset builder.
    split: tensorflow_datasets data split.
    preprocess_fn: Function that preprocesses each data example.
    batch_size: (Integer) Batch size.
    is_training: (Boolean) Whether in training or inference mode.

  Returns:
    Data iterator.
  """
  dataset = dataset_builder.as_dataset(split=split, shuffle_files=True)
  dataset = dataset.map(preprocess_fn,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if is_training:
    # 4096 is ~0.625 GB of RAM. Reduce if memory issues are encountered.
    dataset = dataset.shuffle(buffer_size=4096)
  dataset = dataset.repeat(-1 if is_training else 1)
  dataset = dataset.batch(batch_size, drop_remainder=is_training)

  if not is_training:
    # Pad the remainder of the last batch to make the batch size fixed.
    dataset = utils.pad_to_batch(dataset, batch_size)

  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  return dataset.make_one_shot_iterator()
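
# Usage sketch (hypothetical values; the provider classes below follow this
# pattern):
#   iterator = construct_iterator(builder, split, preprocess_fn, 32, False)
#   images, labels, mask = iterator.get_next()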


class MNISTDataProvider(object):
  """MNIST Data Provider.

  Attributes:
    images: (4-D tensor) Images of shape (batch, height, width, channels).
    labels: (1-D tensor) Data labels of size (batch,).
    mask: (1-D boolean tensor) Data mask. Used when data is not repeated to
      indicate the fraction of the batch with true data in the final batch.
    num_classes: (Integer) Number of classes in the dataset.
    num_examples: (Integer) Number of examples in the dataset.
    class_names: (List of strings) MNIST ids for class labels.
    num_channels: (Integer) Number of image color channels.
    image_size: (Integer) Size of the image.
    iterator: TensorFlow data iterator.
  """

  def __init__(self,
               subset,
               batch_size,
               is_training,
               data_dir=None):
    dataset_builder = tfds.builder("mnist", data_dir=data_dir)
    dataset_builder.download_and_prepare(download_dir=data_dir)
    self.image_size = 28
    if subset == "train":
      split = tfds.core.ReadInstruction("train", from_=8, to=100, unit="%")
    elif subset == "valid":
      split = tfds.core.ReadInstruction("train", from_=0, to=8, unit="%")
    elif subset == "test":
      split = tfds.Split.TEST
    else:
      raise ValueError("subset %s is undefined" % subset)
    self.num_channels = 1
    iterator = construct_iterator(
        dataset_builder, split, self._preprocess_fn(), batch_size, is_training)

    info = dataset_builder.info
    self.iterator = iterator
    self.images, self.labels, self.mask = iterator.get_next()
    self.num_classes = info.features["label"].num_classes
    self.class_names = info.features["label"].names
    self.num_examples = info.splits[split].num_examples

  def _preprocess_fn(self):
    """Preprocessing function."""
    image_size = self.image_size
    def preprocess_image(image):
      """Rescales the image to [-1, 1] and crops or pads it to image_size."""
      image = tf.cast(image, dtype=tf.float32)
      image = image / 255.
      image = 2 * image - 1
      image = tf.image.resize_image_with_crop_or_pad(
          image, image_size, image_size)
      return image

    preprocess_fn = functools.partial(extract_data,
                                      preprocess_image=preprocess_image)
    return preprocess_fn


class CIFAR10DataProvider(object):
  """CIFAR10 Data Provider.

  Attributes:
    images: (4-D tensor) Images of shape (batch, height, width, channels).
    labels: (1-D tensor) Data labels of size (batch,).
    mask: (1-D boolean tensor) Data mask. Used when data is not repeated to
      indicate the fraction of the batch with true data in the final batch.
    num_classes: (Integer) Number of classes in the dataset.
    num_examples: (Integer) Number of examples in the dataset.
    class_names: (List of strings) CIFAR10 ids for class labels.
    num_channels: (Integer) Number of image color channels.
    image_size: (Integer) Size of the image.
    iterator: TensorFlow data iterator.
  """

  def __init__(self,
               subset,
               batch_size,
               is_training,
               data_dir=None):
    dataset_builder = tfds.builder("cifar10", data_dir=data_dir)
    dataset_builder.download_and_prepare(download_dir=data_dir)
    self.image_size = 32

    if subset == "train":
      split = tfds.core.ReadInstruction("train", from_=10, to=100, unit="%")
    elif subset == "valid":
      split = tfds.core.ReadInstruction("train", from_=0, to=10, unit="%")
    elif subset == "test":
      split = tfds.Split.TEST
    else:
      raise ValueError("subset %s is undefined" % subset)
    self.num_channels = 3
    iterator = construct_iterator(
        dataset_builder, split, self._preprocess_fn(), batch_size, is_training)
    info = dataset_builder.info
    self.iterator = iterator
    self.images, self.labels, self.mask = iterator.get_next()
    self.num_classes = info.features["label"].num_classes
    self.class_names = info.features["label"].names
    self.num_examples = info.splits[split].num_examples

  def _preprocess_fn(self):
    """Preprocessing function."""
    image_size = self.image_size
    def preprocess_image(image):
      """Crops or pads the image to image_size."""
      image = tf.image.resize_image_with_crop_or_pad(
          image, image_size, image_size)
      return image

    preprocess_fn = functools.partial(extract_data,
                                      preprocess_image=preprocess_image)
    return preprocess_fn


def extract_data_celeba(data, preprocess_image, attribute="Male"):
"""Extracts image, label and create a mask (used by CelebA data provider)."""
image = data["image"]
# Reserve label 0 for background
label = tf.cast(data["attributes"][attribute], dtype=tf.int32)
# Create a mask variable to track the real vs padded data in the last batch.
mask = 1.
image = preprocess_image(image)
return image, label, mask


class CelebADataProvider(object):
  """CelebA Data Provider.

  Attributes:
    images: (4-D tensor) Images of shape (batch, height, width, channels).
    labels: (1-D tensor) Data labels of size (batch,).
    mask: (1-D boolean tensor) Data mask. Used when data is not repeated to
      indicate the fraction of the batch with true data in the final batch.
    num_classes: (Integer) Number of classes in the dataset.
    num_examples: (Integer) Number of examples in the dataset.
    num_channels: (Integer) Number of image color channels.
    image_size: (Integer) Size of the image.
    iterator: TensorFlow data iterator.
    class_names: (List of strings) Names of classes in the order of the labels.
  """

  def __init__(self,
               subset,
               batch_size,
               is_training,
               data_dir=None):
    self.image_size = 32
    dataset_builder = tfds.builder("celeb_a", data_dir=data_dir)
    dataset_builder.download_and_prepare(download_dir=data_dir)
    if subset == "train":
      split = tfds.Split.TRAIN
    elif subset == "valid":
      split = tfds.Split.VALIDATION
    elif subset == "test":
      split = tfds.Split.TEST
    else:
      raise ValueError("subset %s is undefined for the dataset" % subset)
    self.num_channels = 3
    iterator = construct_iterator(
        dataset_builder, split, self._preprocess_fn(), batch_size, is_training)
    info = dataset_builder.info
    self.iterator = iterator
    self.images, self.labels, self.mask = iterator.get_next()
    self.num_classes = 2
    self.class_names = ["Female", "Male"]
    self.num_examples = info.splits[split].num_examples

  def _preprocess_fn(self):
    """Preprocessing function."""
    crop = True
    image_size = self.image_size
    def preprocess_image(image):
      """Preprocesses a single image.

      Args:
        image: Tensor representing a single image example of arbitrary size.

      Returns:
        Preprocessed image.
      """
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
      if crop:
        # Crop away the borders: the raw CelebA images are 218x178.
        image = tf.image.crop_to_bounding_box(image, 40, 20, 218 - 80,
                                              178 - 40)
      image = tf.image.resize_bicubic([image], [image_size, image_size])[0]
      return image

    preprocess_fn = functools.partial(extract_data_celeba,
                                      preprocess_image=preprocess_image,
                                      attribute="Male")
    return preprocess_fn


# ===== Functions that provide data. =====
_DATASETS = {
    "cifar10": CIFAR10DataProvider,
    "mnist": MNISTDataProvider,
    "celeba32": CelebADataProvider,
}


def get_data_provider(dataset_name):
  """Returns a data provider class by dataset name."""
  return _DATASETS[dataset_name]
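
# Usage sketch (hypothetical values, TF1 graph mode; assumes the dataset can
# be downloaded via tensorflow_datasets):
#   provider = get_data_provider("mnist")(
#       subset="train", batch_size=32, is_training=True)
#   with tf.Session() as sess:
#     images, labels, mask = sess.run(
#         [provider.images, provider.labels, provider.mask])
#   # images.shape == (32, 28, 28, 1); labels.shape == (32,)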
71 changes: 71 additions & 0 deletions low_rank_local_connectivity/data_provider_test.py
@@ -0,0 +1,71 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests data provider."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import tensorflow as tf

from low_rank_local_connectivity.data_provider import get_data_provider

_IMAGE_SHAPE_DICT = {
    "mnist": (28, 28, 1),
    "cifar10": (32, 32, 3),
    "celeba32": (32, 32, 3),
}


def _get_test_cases():
  """Provides test cases."""
  is_training = [True, False]
  subset = ["train", "valid", "test"]
  dataset_name = _IMAGE_SHAPE_DICT.keys()
  i = 0
  cases = []
  for d in dataset_name:
    for s in subset:
      for t in is_training:
        cases.append(("case_%d" % i, d, s, t))
        i += 1
  return tuple(cases)


class DataProviderTest(tf.test.TestCase, parameterized.TestCase):

  def test_import(self):
    self.assertIsNotNone(get_data_provider)

  @parameterized.named_parameters(*_get_test_cases())
  def test_dataset(self, dataset_name, subset, is_training):
    batch_size = 1
    image_shape = _IMAGE_SHAPE_DICT[dataset_name]
    dataset = get_data_provider(dataset_name)(
        subset=subset,
        batch_size=batch_size,
        is_training=is_training)
    images, labels = dataset.images, dataset.labels

    im, l = self.evaluate((images, labels))
    self.assertEqual(im.shape, (batch_size,) + image_shape)
    self.assertEqual(l.shape, (batch_size,))


if __name__ == "__main__":
  tf.test.main()
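
# To run the tests (a sketch; requires tensorflow, tensorflow_datasets, and
# absl-py, and will download the datasets on first run):
#   python -m low_rank_local_connectivity.data_provider_test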