Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jul 4, 2023
1 parent b32cb0e commit 67a23b8
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 7 deletions.
5 changes: 3 additions & 2 deletions benchmarks/model_benchmark/image_classification_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Image classification benchmark.
This script runs image classification benchmark with "dogs vs cats" datasets. It
supports the following 3 models:
This script runs image classification benchmark with "dogs vs cats" datasets.
It supports the following 3 models:
- EfficientNetV2B0
- Xception
- ResNet50V2
Expand Down
4 changes: 3 additions & 1 deletion keras_core/backend/common/global_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def get_global_setting(name, default=None, set_to_default=False):
return attr


@keras_core_export("keras_core.backend.clear_session")
@keras_core_export(
["keras_core.utils.clear_session", "keras_core.backend.clear_session"]
)
def clear_session():
"""Resets all state generated by Keras.
Expand Down
4 changes: 1 addition & 3 deletions keras_core/optimizers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
BackendOptimizer = base_optimizer.BaseOptimizer


keras_core_export(["keras_core.Optimizer", "keras_core.optimizers.Optimizer"])


@keras_core_export(["keras_core.Optimizer", "keras_core.optimizers.Optimizer"])
class Optimizer(BackendOptimizer):
pass

Expand Down
50 changes: 50 additions & 0 deletions keras_core/saving/saving_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,63 @@
from keras_core.api_export import keras_core_export
from keras_core.saving import saving_lib
from keras_core.saving.legacy import legacy_h5_format
from keras_core.utils import io_utils

try:
import h5py
except ImportError:
h5py = None


@keras_core_export(
["keras_core.saving.save_model", "keras_core.models.save_model"]
)
def save_model(model, filepath, overwrite=True):
"""Saves a model as a `.keras` file.
Args:
model: Keras model instance to be saved.
filepath: `str` or `pathlib.Path` object. Path where to save the model.
overwrite: Whether we should overwrite any existing model at the target
location, or instead ask the user via an interactive prompt.
Example:
```python
model = keras_core.Sequential(
[
keras_core.layers.Dense(5, input_shape=(3,)),
keras_core.layers.Softmax(),
],
)
model.save("model.keras")
loaded_model = keras_core.saving.load_model("model.keras")
x = keras.random.uniform((10, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
```
Note that `model.save()` is an alias for `keras_core.saving.save_model()`.
The saved `.keras` file contains:
- The model's configuration (architecture)
- The model's weights
- The model's optimizer's state (if any)
Thus models can be reinstantiated in the exact same state.
"""
# If file exists and should not be overwritten.
try:
exists = os.path.exists(filepath)
except TypeError:
exists = False
if exists and not overwrite:
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
if not proceed:
return
saving_lib.save_model(model, filepath)


@keras_core_export(
["keras_core.saving.load_model", "keras_core.models.load_model"]
)
Expand Down
2 changes: 1 addition & 1 deletion keras_core/saving/saving_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def save_model(model, filepath, weights_format="h5"):
- JSON-based configuration file (config.json): Records of model, layer, and
other trackables' configuration.
- NPZ-based trackable state files, found in respective directories, such as
- H5-based trackable state files, found in respective directories, such as
model/states.npz, model/dense_layer/states.npz, etc.
- Metadata file.
Expand Down

0 comments on commit 67a23b8

Please sign in to comment.