Project import generated by Copybara.
PiperOrigin-RevId: 363220025
ModelSearch authored and hanna-maz committed Mar 17, 2021
1 parent 6cdf17e commit 8c5eed4
Showing 18 changed files with 346 additions and 4 deletions.
35 changes: 33 additions & 2 deletions README.md
@@ -66,10 +66,41 @@ The search will be performed according to the default specification. That can be
For more details about the fields and if you want to create your own specification, you
can look at: `model_search/proto/phoenix_spec.proto`.
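If you write your own spec, you can point the trainer at it directly. Below is a minimal sketch of that wiring, assuming `spec` accepts a filesystem path to a pbtxt file (the image data test added in this commit passes one); `my_phoenix_spec.pbtxt` is a hypothetical file you would have authored, and the data provider is the image provider from the next example (any provider works the same way):

```python
from model_search import single_trainer
from model_search.data import image_data

# Hypothetical: a spec you wrote based on model_search/proto/phoenix_spec.proto.
my_spec_path = "my_phoenix_spec.pbtxt"

trainer = single_trainer.SingleTrainer(
    data=image_data.Provider(
        input_dir="model_search/data/testdata/images",
        image_height=100,
        image_width=100,
        eval_fraction=0.2),
    spec=my_spec_path)
```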

Now, what if you don't have a csv with the features? The next section shows
### Image data example
Below is an example of binary classification for images.

```python
import model_search
from model_search import constants
from model_search import single_trainer
from model_search.data import image_data

trainer = single_trainer.SingleTrainer(
    data=image_data.Provider(
        input_dir="model_search/data/testdata/images",
        image_height=100,
        image_width=100,
        eval_fraction=0.2),
    spec=constants.DEFAULT_CNN)

trainer.try_models(
    number_models=200,
    train_steps=1000,
    eval_steps=100,
    root_dir="/tmp/run_example",
    batch_size=32,
    experiment_name="example",
    experiment_owner="model_search_user")
```
The API above uses the same input fields as `tf.keras.preprocessing.image_dataset_from_directory`, as the sketch below shows.
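For reference, here is a minimal sketch of the equivalent direct Keras call with the same fields filled in the way the provider forwards them; the directory, image sizes, and 0.2 split are simply the example values from the snippet above:

```python
import tensorflow.compat.v2 as tf

# Illustrative only: the same arguments the image provider hands to Keras.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory="model_search/data/testdata/images",  # one sub-folder per class
    labels="inferred",
    label_mode="binary",
    image_size=(100, 100),  # (image_height, image_width)
    batch_size=32,
    shuffle=True,
    seed=73,
    validation_split=0.2,  # eval_fraction
    subset="training")  # the provider uses "validation" for eval
```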

The search will be performed according to the default specification. That can be found in:
`model_search/configs/cnn_config.pbtxt`.

Now, what if you don't have a csv with the features or images? The next section shows
how to run without a csv.

## Non-csv data
## Non-csv, Non-image data
To run with non-csv data, you will have to implement a class that inherits from the abstract
class `model_search.data.Provider`. This enables us to define our own
`input_fn` and hence customize the feature columns and the task (i.e., the number
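A minimal sketch of such a provider is shown below. It is modeled on the image provider added in this commit and assumes a small in-memory numpy dataset; `NumpyProvider` and its `features`/`labels` arguments are hypothetical names, and the authoritative abstract interface is the one defined in `model_search/data/data.py`:

```python
import numpy as np
import tensorflow.compat.v2 as tf

from model_search.data import data


class NumpyProvider(data.Provider):
  """Hypothetical provider serving an in-memory numpy dataset."""

  def __init__(self, features, labels, eval_fraction=0.2):
    # features: float array [n, d]; labels: int array [n] with values 0..k-1.
    self._features = np.asarray(features, dtype=np.float32)
    self._labels = np.asarray(labels, dtype=np.int32)
    self._eval_fraction = eval_fraction

  def get_input_fn(self, hparams, mode, batch_size):
    del hparams

    def input_fn(params=None):
      del params
      # Hold out the first eval_fraction of examples for evaluation.
      n_eval = int(len(self._labels) * self._eval_fraction)
      if mode == tf.estimator.ModeKeys.TRAIN:
        x, y = self._features[n_eval:], self._labels[n_eval:]
      else:
        x, y = self._features[:n_eval], self._labels[:n_eval]
      dataset = tf.data.Dataset.from_tensor_slices(({'features': x}, y))
      if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.shuffle(buffer_size=len(y)).repeat()
      return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    return input_fn

  def get_serving_input_fn(self, hparams):
    del hparams
    tf.compat.v1.disable_eager_execution()
    features = {
        'features':
            tf.compat.v1.placeholder(tf.float32,
                                     [None, self._features.shape[1]],
                                     'features')
    }
    return tf.estimator.export.build_raw_serving_input_receiver_fn(
        features=features)

  def number_of_classes(self):
    return int(self._labels.max()) + 1

  def get_feature_columns(self):
    return [
        tf.feature_column.numeric_column(
            key='features', shape=(self._features.shape[1],))
    ]
```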
2 changes: 1 addition & 1 deletion model_search/blocks.py
@@ -84,7 +84,7 @@ def requires_hparams(self):
return None


# NEXT ID: 146
# NEXT ID: 147
# NEXT EXPERIMENTAL ID: 10017 (experiment id starts at 10,001)
register_block = functools.partial(registry.register, base=Block)

2 changes: 1 addition & 1 deletion model_search/blocks_builder_test.py
@@ -58,7 +58,7 @@ def test_naming_of_tunable(self):
for idx, name in enumerate(names):
for idx2, name2 in enumerate(names):
if idx != idx2:
self.assertNotStartsWith(name, name2)
self.assertNotStartsWith(name.name, name2.name)


if __name__ == "__main__":
51 changes: 51 additions & 0 deletions model_search/data/BUILD
@@ -63,6 +63,57 @@ model_search_oss_binary(
dataset_dep = ":csv_data_for_binary",
)

py_library(
name = "image_data",
srcs = ["image_data.py"],
srcs_version = "PY3",
deps = [":data"],
)

py_library(
name = "image_data_for_binary",
srcs = ["image_data_for_binary.py"],
srcs_version = "PY3",
deps = [":data"],
)

py_test(
name = "image_data_test",
srcs = ["image_data_test.py"],
data = [
"//model_search/configs:phoenix_configs",
"//model_search/data/testdata:image_data",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":image_data",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
"//model_search:constants",
"//model_search:single_trainer",
],
)

model_search_oss_test(
name = "image_data_for_binary_test",
dataset_dep = ":image_data_for_binary",
extra_args = [
"--input_dir=$${TEST_SRCDIR}/model_search/data/testdata/images",
"--image_height=100",
"--image_width=100",
],
problem_type = "cnn",
test_data = [
"//model_search/data/testdata:image_data",
],
)

model_search_oss_binary(
name = "image_data_binary",
dataset_dep = ":image_data_for_binary",
)

py_library(
name = "data",
srcs = ["data.py"],
90 changes: 90 additions & 0 deletions model_search/data/image_data.py
@@ -0,0 +1,90 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple csv reader for small classification problems."""

from model_search.data import data
import tensorflow.compat.v2 as tf


class Provider(data.Provider):
"""A csv data provider."""

def __init__(self, input_dir, image_width, image_height, eval_fraction):
self._input_dir = input_dir
self._image_width = image_width
self._image_height = image_height
self._eval_fraction = eval_fraction

def get_input_fn(self, hparams, mode, batch_size):
"""See `data.Provider` get_input_fn."""
del hparams

def input_fn(params=None):
del params
split = ('training'
if mode == tf.estimator.ModeKeys.TRAIN else 'validation')

dataset = tf.keras.preprocessing.image_dataset_from_directory(
directory=self._input_dir,
labels='inferred',
label_mode='binary',
class_names=None,
color_mode='rgb',
batch_size=batch_size,
image_size=(self._image_height, self._image_width),
shuffle=True,
seed=73,
validation_split=self._eval_fraction,
subset=split,
interpolation='bilinear',
follow_links=False)

if mode == tf.estimator.ModeKeys.TRAIN:
dataset = dataset.cache().prefetch(
buffer_size=tf.data.experimental.AUTOTUNE)

return dataset

return input_fn

def get_serving_input_fn(self, hparams):
"""Returns an `input_fn` for serving in an exported SavedModel.
Args:
hparams: tf.HParams object.
Returns:
Returns an `input_fn` that takes no arguments and returns a
`ServingInputReceiver`.
"""
tf.compat.v1.disable_eager_execution()
features = {
'image':
tf.compat.v1.placeholder(
tf.float32, [None, self._image_height, self._image_width, 3],
'image')
}
return tf.estimator.export.build_raw_serving_input_receiver_fn(
features=features)

def number_of_classes(self):
return 2

def get_feature_columns(self):
"""Returns feature columns."""
feature_columns = [
tf.feature_column.numeric_column(
key='image', shape=(self._image_height, self._image_width, 3))
]
return feature_columns
111 changes: 111 additions & 0 deletions model_search/data/image_data_for_binary.py
@@ -0,0 +1,111 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple csv reader for small classification problems."""

from absl import flags

from model_search.data import data
import tensorflow.compat.v2 as tf

flags.DEFINE_string(
'input_dir', '', 'The path containing the input data. Should provide a dir '
'that has 0 and 1 as subdirs.')

flags.DEFINE_float('eval_fraction', 0.2,
'The amount of data (fraction) to hold for evaluation.')

flags.DEFINE_integer('image_height', 320,
'The height (dimension) of the image.')

flags.DEFINE_integer('image_width', 240, 'The width (dimension) of the image.')

FLAGS = flags.FLAGS


@data.register_provider(lookup_name='image_data_provider', init_args={})
class Provider(data.Provider):
"""A csv data provider."""

def __init__(self):
self._input_dir = FLAGS.input_dir
self._image_width = FLAGS.image_width
self._image_height = FLAGS.image_height
self._eval_fraction = FLAGS.eval_fraction

if '${TEST_SRCDIR}' in self._input_dir:
self._input_dir = self._input_dir.replace('${TEST_SRCDIR}',
FLAGS.test_srcdir)

def get_input_fn(self, hparams, mode, batch_size):
"""See `data.Provider` get_input_fn."""
del hparams

def input_fn(params=None):
del params
split = ('training'
if mode == tf.estimator.ModeKeys.TRAIN else 'validation')

dataset = tf.keras.preprocessing.image_dataset_from_directory(
directory=self._input_dir,
labels='inferred',
label_mode='binary',
class_names=None,
color_mode='rgb',
batch_size=batch_size,
image_size=(self._image_height, self._image_width),
shuffle=True,
seed=73,
validation_split=self._eval_fraction,
subset=split,
interpolation='bilinear',
follow_links=False)

if mode == tf.estimator.ModeKeys.TRAIN:
dataset = dataset.cache().prefetch(
buffer_size=tf.data.experimental.AUTOTUNE)

return dataset

return input_fn

def get_serving_input_fn(self, hparams):
"""Returns an `input_fn` for serving in an exported SavedModel.
Args:
hparams: tf.HParams object.
Returns:
Returns an `input_fn` that takes no arguments and returns a
`ServingInputReceiver`.
"""
tf.compat.v1.disable_eager_execution()
features = {
'image':
tf.compat.v1.placeholder(
tf.float32, [None, self._image_height, self._image_width, 3],
'image')
}
return tf.estimator.export.build_raw_serving_input_receiver_fn(
features=features)

def number_of_classes(self):
return 2

def get_feature_columns(self):
"""Returns feature columns."""
feature_columns = [
tf.feature_column.numeric_column(
key='image', shape=(self._image_height, self._image_width, 3))
]
return feature_columns
52 changes: 52 additions & 0 deletions model_search/data/image_data_test.py
@@ -0,0 +1,52 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for model_search.single_trainer."""

import os
from absl import flags
from absl.testing import absltest
from model_search import constants
from model_search import single_trainer
from model_search.data import image_data

FLAGS = flags.FLAGS


class SingleTrainerTest(absltest.TestCase):

def test_try_models(self):
# Test assumes source code is deployed in FLAGS.test_srcdir.
spec_path = os.path.join(FLAGS.test_srcdir, constants.DEFAULT_CNN)
trainer = single_trainer.SingleTrainer(
data=image_data.Provider(
input_dir=os.path.join(
FLAGS.test_srcdir,
"model_search/model_search/data/testdata/images"),
image_width=100,
image_height=100,
eval_fraction=0.2),
spec=spec_path)

trainer.try_models(
number_models=7,
train_steps=10,
eval_steps=10,
root_dir=FLAGS.test_tmpdir,
batch_size=2,
experiment_name="test",
experiment_owner="test")


if __name__ == "__main__":
absltest.main()
7 changes: 7 additions & 0 deletions model_search/data/testdata/BUILD
@@ -10,3 +10,10 @@ filegroup(
"csv_random_data.csv",
],
)

filegroup(
name = "image_data",
srcs = glob([
"**/*.png",
]),
)
Binary file added model_search/data/testdata/images/0/1.png
Binary file added model_search/data/testdata/images/0/2.png
Binary file added model_search/data/testdata/images/0/3.png
Binary file added model_search/data/testdata/images/0/4.png
Binary file added model_search/data/testdata/images/0/5.png
Binary file added model_search/data/testdata/images/1/10.png
Binary file added model_search/data/testdata/images/1/6.png
Binary file added model_search/data/testdata/images/1/7.png
Binary file added model_search/data/testdata/images/1/8.png
Binary file added model_search/data/testdata/images/1/9.png
