forked from google-research/google-research
Low rank locally connected layer, data, and model used for the paper.
PiperOrigin-RevId: 295752649
1 parent ac96079 · commit fdc4c03
Showing 11 changed files with 1,861 additions and 0 deletions.
@@ -0,0 +1,14 @@
Revisiting Spatial Invariance with Low-Rank Local Connectivity
https://arxiv.org/abs/2002.02959

Work in progress

This is the directory for the low-rank locally connected layer and experiments.

We develop a low-rank locally connected (LRLC) layer that can parametrically
adjust the degree of spatial invariance. This layer is one particular method
to relax spatial invariance by reducing weight sharing. Rather than learning a
single filter bank to apply at all positions, as in a convolutional layer, or
different filter banks, as in a locally connected layer, the LRLC layer learns
a set of K filter banks, which are linearly combined using K combining weights
per spatial position.
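As an illustrative aside (not part of this commit), the following toy NumPy sketch shows the combination described above in a dense, patch-per-position formulation; the function and array names are ours, and a real LRLC layer would construct the patches via its convolution geometry.

import numpy as np


def lrlc_combine(patches, filter_banks, combining_weights):
  """Toy LRLC combination: per-position weighted sum of K filter banks.

  patches: (H, W, patch_dim) flattened input patch at each spatial position.
  filter_banks: (K, patch_dim, out_channels) the K shared filter banks.
  combining_weights: (H, W, K) combining weights per spatial position.
  """
  # Effective filter at each position is a weighted sum of the K banks.
  w = np.einsum("hwk,kdo->hwdo", combining_weights, filter_banks)
  # Apply the position-specific filter to the patch at that position.
  return np.einsum("hwd,hwdo->hwo", patches, w)


# Example: 8x8 positions, 3x3x3 patches (27 dims), K=2 banks, 16 output channels.
out = lrlc_combine(np.random.rand(8, 8, 27),
                   np.random.rand(2, 27, 16),
                   np.random.rand(8, 8, 2))
print(out.shape)  # (8, 8, 16)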
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,289 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data provider with an argument to control data augmentation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tensorflow as tf
import tensorflow_datasets as tfds

from low_rank_local_connectivity import utils


def extract_data(data, preprocess_image):
  """Extracts image, label and creates a mask."""
  image = data["image"]
  # Reserve label 0 for background.
  label = tf.cast(data["label"], dtype=tf.int32)
  # Create a mask variable to track the real vs padded data in the last batch.
  mask = 1.
  image = preprocess_image(image)
  return image, label, mask


def construct_iterator(dataset_builder,
                       split,
                       preprocess_fn,
                       batch_size,
                       is_training):
  """Constructs data iterator.

  Args:
    dataset_builder: tensorflow_datasets data builder.
    split: tensorflow_datasets data split.
    preprocess_fn: Function that preprocesses each data example.
    batch_size: (Integer) Batch size.
    is_training: (boolean) Whether training or inference mode.

  Returns:
    Data iterator.
  """
  dataset = dataset_builder.as_dataset(split=split, shuffle_files=True)
  dataset = dataset.map(preprocess_fn,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if is_training:
    # 4096 is ~0.625 GB of RAM. Reduce if memory issues encountered.
    dataset = dataset.shuffle(buffer_size=4096)
  dataset = dataset.repeat(-1 if is_training else 1)
  dataset = dataset.batch(batch_size, drop_remainder=is_training)

  if not is_training:
    # Pad the remainder of the last batch to make batch size fixed.
    dataset = utils.pad_to_batch(dataset, batch_size)

  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  return dataset.make_one_shot_iterator()


class MNISTDataProvider(object):
  """MNIST Data Provider.

  Attributes:
    images: (4-D tensor) Images of shape (batch, height, width, channels).
    labels: (1-D tensor) Data labels of size (batch,).
    mask: (1-D boolean tensor) Data mask. Used when data is not repeated to
      indicate the fraction of the batch with true data in the final batch.
    num_classes: (Integer) Number of classes in the dataset.
    num_examples: (Integer) Number of examples in the dataset.
    class_names: (List of Strings) MNIST ids for class labels.
    num_channels: (Integer) Number of image color channels.
    image_size: (Integer) Size of the image.
    iterator: Tensorflow data iterator.
  """

  def __init__(self,
               subset,
               batch_size,
               is_training,
               data_dir=None):
    dataset_builder = tfds.builder("mnist", data_dir=data_dir)
    dataset_builder.download_and_prepare(download_dir=data_dir)
    self.image_size = 28
    if subset == "train":
      split = tfds.core.ReadInstruction("train", from_=8, to=100, unit="%")
    elif subset == "valid":
      split = tfds.core.ReadInstruction("train", from_=0, to=8, unit="%")
    elif subset == "test":
      split = tfds.Split.TEST
    else:
      raise ValueError("subset %s is undefined" % subset)
    self.num_channels = 1
    iterator = construct_iterator(
        dataset_builder, split, self._preprocess_fn(), batch_size, is_training)

    info = dataset_builder.info
    self.iterator = iterator
    self.images, self.labels, self.mask = iterator.get_next()
    self.num_classes = info.features["label"].num_classes
    self.class_names = info.features["label"].names
    self.num_examples = info.splits[split].num_examples

  def _preprocess_fn(self):
    """Preprocessing function."""
    image_size = self.image_size
    def preprocess_image(image):
      """Preprocessing."""
      image = tf.cast(image, dtype=tf.float32)
      image = image / 255.
      image = 2 * image - 1
      image = tf.image.resize_image_with_crop_or_pad(
          image, image_size, image_size)
      return image

    preprocess_fn = functools.partial(extract_data,
                                      preprocess_image=preprocess_image)
    return preprocess_fn


class CIFAR10DataProvider(object):
  """CIFAR10 Data Provider.

  Attributes:
    images: (4-D tensor) Images of shape (batch, height, width, channels).
    labels: (1-D tensor) Data labels of size (batch,).
    mask: (1-D boolean tensor) Data mask. Used when data is not repeated to
      indicate the fraction of the batch with true data in the final batch.
    num_classes: (Integer) Number of classes in the dataset.
    num_examples: (Integer) Number of examples in the dataset.
    class_names: (List of Strings) CIFAR10 ids for class labels.
    num_channels: (Integer) Number of image color channels.
    image_size: (Integer) Size of the image.
    iterator: Tensorflow data iterator.
  """

  def __init__(self,
               subset,
               batch_size,
               is_training,
               data_dir=None):
    dataset_builder = tfds.builder("cifar10", data_dir=data_dir)
    dataset_builder.download_and_prepare(download_dir=data_dir)
    self.image_size = 32

    if subset == "train":
      split = tfds.core.ReadInstruction("train", from_=10, to=100, unit="%")
    elif subset == "valid":
      split = tfds.core.ReadInstruction("train", from_=0, to=10, unit="%")
    elif subset == "test":
      split = tfds.Split.TEST
    else:
      raise ValueError("subset %s is undefined" % subset)
    self.num_channels = 3
    iterator = construct_iterator(
        dataset_builder, split, self._preprocess_fn(), batch_size, is_training)
    info = dataset_builder.info
    self.iterator = iterator
    self.images, self.labels, self.mask = iterator.get_next()
    self.num_classes = info.features["label"].num_classes
    self.class_names = info.features["label"].names
    self.num_examples = info.splits[split].num_examples

  def _preprocess_fn(self):
    """Preprocessing function."""
    image_size = self.image_size
    def preprocess_image(image):
      """Preprocessing."""
      image = tf.image.resize_image_with_crop_or_pad(
          image, image_size, image_size)
      return image

    preprocess_fn = functools.partial(extract_data,
                                      preprocess_image=preprocess_image)
    return preprocess_fn


def extract_data_celeba(data, preprocess_image, attribute="Male"):
  """Extracts image, label and creates a mask (used by CelebA data provider)."""
  image = data["image"]
  label = tf.cast(data["attributes"][attribute], dtype=tf.int32)
  # Create a mask variable to track the real vs padded data in the last batch.
  mask = 1.
  image = preprocess_image(image)
  return image, label, mask


class CelebADataProvider(object):
  """CelebA Data Provider.

  Attributes:
    images: (4-D tensor) Images of shape (batch, height, width, channels).
    labels: (1-D tensor) Data labels of size (batch,).
    mask: (1-D boolean tensor) Data mask. Used when data is not repeated to
      indicate the fraction of the batch with true data in the final batch.
    num_classes: (Integer) Number of classes in the dataset.
    num_examples: (Integer) Number of examples in the dataset.
    num_channels: (Integer) Number of image color channels.
    image_size: (Integer) Size of the image.
    iterator: Tensorflow data iterator.
    class_names: (List of Strings) Names of classes in the order of the labels.
  """

  def __init__(self,
               subset,
               batch_size,
               is_training,
               data_dir=None):
    self.image_size = 32

    dataset_builder = tfds.builder("celeb_a",
                                   data_dir=data_dir)
    dataset_builder.download_and_prepare(download_dir=data_dir)
    if subset == "train":
      split = tfds.Split.TRAIN
    elif subset == "valid":
      split = tfds.Split.VALIDATION
    elif subset == "test":
      split = tfds.Split.TEST
    else:
      raise ValueError(
          "subset %s is undefined for the dataset" % subset)
    self.num_channels = 3
    iterator = construct_iterator(
        dataset_builder, split, self._preprocess_fn(), batch_size, is_training)
    info = dataset_builder.info
    self.iterator = iterator
    self.images, self.labels, self.mask = iterator.get_next()
    self.num_classes = 2
    self.class_names = ["Female", "Male"]
    self.num_examples = info.splits[split].num_examples

  def _preprocess_fn(self):
    """Preprocessing."""
    crop = True
    image_size = self.image_size
    def preprocess_image(image):
      """Preprocesses the given image.

      Args:
        image: Tensor `image` representing a single image example of
          arbitrary size.

      Returns:
        Preprocessed image.
      """
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
      if crop:
        image = tf.image.crop_to_bounding_box(image, 40, 20, 218 - 80, 178 - 40)

      image = tf.image.resize_bicubic([image], [image_size, image_size])[0]
      return image

    preprocess_fn = functools.partial(extract_data_celeba,
                                      preprocess_image=preprocess_image,
                                      attribute="Male")
    return preprocess_fn


# ===== Function that provides data. ======
_DATASETS = {
    "cifar10": CIFAR10DataProvider,
    "mnist": MNISTDataProvider,
    "celeba32": CelebADataProvider,
}


def get_data_provider(dataset_name):
  """Returns dataset by name."""
  return _DATASETS[dataset_name]
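For context, here is a minimal usage sketch of the providers above (ours, not part of the commit), assuming a TF 1.x runtime with graph execution as used throughout this file; the first call downloads and prepares the dataset via tensorflow_datasets.

import tensorflow as tf

from low_rank_local_connectivity.data_provider import get_data_provider

# Build the MNIST provider; its tensors are graph ops until run in a session.
provider = get_data_provider("mnist")(
    subset="train", batch_size=32, is_training=True)

with tf.Session() as sess:
  images, labels = sess.run([provider.images, provider.labels])
  print(images.shape, labels.shape)  # (32, 28, 28, 1) (32,)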
@@ -0,0 +1,71 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests data provider."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import tensorflow as tf

from low_rank_local_connectivity.data_provider import get_data_provider

_IMAGE_SHAPE_DICT = {
    "mnist": (28, 28, 1),
    "cifar10": (32, 32, 3),
    "celeba32": (32, 32, 3),
}


def _get_test_cases():
  """Provides test cases."""
  is_training = [True, False]
  subset = ["train", "valid", "test"]
  dataset_name = _IMAGE_SHAPE_DICT.keys()
  i = 0
  cases = []
  for d in dataset_name:
    for s in subset:
      for t in is_training:
        cases.append(("case_%d" % i, d, s, t))
        i += 1
  return tuple(cases)


class DataProviderTest(tf.test.TestCase, parameterized.TestCase):

  def test_import(self):
    self.assertIsNotNone(get_data_provider)

  @parameterized.named_parameters(*_get_test_cases())
  def test_dataset(self, dataset_name, subset, is_training):
    batch_size = 1
    image_shape = _IMAGE_SHAPE_DICT[dataset_name]
    dataset = get_data_provider(dataset_name)(
        subset=subset,
        batch_size=batch_size,
        is_training=is_training)
    images, labels = dataset.images, dataset.labels

    im, l = self.evaluate((images, labels))
    self.assertEqual(im.shape, (batch_size,) + image_shape)

    self.assertEqual(l.shape, (batch_size,))


if __name__ == "__main__":
  tf.test.main()