Skip to content

Commit

Permalink
Add include_preprocessing to model exporters.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 488371678
  • Loading branch information
dzelle authored and tensorflower-gardener committed Nov 14, 2022
1 parent 1a6f7ad commit d7a9659
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tensorflow_gnn/runner/utils/model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class KerasModelExporter:
def __init__(self,
*,
output_names: Optional[Any] = None,
subdirectory: Optional[str] = None):
subdirectory: Optional[str] = None,
include_preprocessing: bool = True):
"""Captures the args shared across `save(...)` calls.
Args:
Expand All @@ -62,9 +63,11 @@ def __init__(self,
keys of a model output can also be renamed.
subdirectory: An optional subdirectory, if set: models are exported to
`os.path.join(export_dir, subdirectory).`
include_preprocessing: Whether to include any `preprocess_model.`
"""
self._output_names = output_names
self._subdirectory = subdirectory
self._include_preprocessing = include_preprocessing

def save(self,
preprocess_model: Optional[tf.keras.Model],
Expand All @@ -82,7 +85,7 @@ def save(self,
model: A `tf.keras.Model` to save.
export_dir: A destination directory for the model.
"""
if preprocess_model is not None:
if preprocess_model is not None and self._include_preprocessing:
model = model_utils.chain_first_output(preprocess_model, model)
if self._output_names is not None:
output = _rename_output(model.output, self._output_names)
Expand All @@ -99,18 +102,21 @@ def __init__(self,
submodule_name: str,
*,
output_names: Optional[Any] = None,
subdirectory: Optional[str] = None):
subdirectory: Optional[str] = None,
include_preprocessing: bool = False):
"""Captures the args shared across `save(...)` calls.
Args:
submodule_name: The name of the submodule to export.
output_names: The names for output Tensor(s), see: `KerasModelExporter.`
subdirectory: An optional subdirectory, if set: submodules are exported
to `os.path.join(export_dir, subdirectory).`
include_preprocessing: Whether to include any `preprocess_model.`
"""
self._output_names = output_names
self._subdirectory = subdirectory
self._submodule_name = submodule_name
self._include_preprocessing = include_preprocessing

def save(self,
preprocess_model: tf.keras.Model,
Expand Down Expand Up @@ -144,6 +150,7 @@ def save(self,

exporter = KerasModelExporter(
output_names=self._output_names,
subdirectory=self._subdirectory)
subdirectory=self._subdirectory,
include_preprocessing=self._include_preprocessing)

exporter.save(preprocess_model, submodel, export_dir)

0 comments on commit d7a9659

Please sign in to comment.