Skip to content

Commit 67a23b8

Browse files
committed
Minor fixes
1 parent b32cb0e commit 67a23b8

File tree

5 files changed

+58
-7
lines changed

5 files changed

+58
-7
lines changed

benchmarks/model_benchmark/image_classification_benchmark.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Image classification benchmark.
22
3-
This script runs image classification benchmark with "dogs vs cats" datasets. It
4-
supports the following 3 models:
3+
This script runs image classification benchmark with "dogs vs cats" datasets.
4+
It supports the following 3 models:
5+
56
- EfficientNetV2B0
67
- Xception
78
- ResNet50V2

keras_core/backend/common/global_state.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ def get_global_setting(name, default=None, set_to_default=False):
3333
return attr
3434

3535

36-
@keras_core_export("keras_core.backend.clear_session")
36+
@keras_core_export(
37+
["keras_core.utils.clear_session", "keras_core.backend.clear_session"]
38+
)
3739
def clear_session():
3840
"""Resets all state generated by Keras.
3941

keras_core/optimizers/optimizer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
BackendOptimizer = base_optimizer.BaseOptimizer
1111

1212

13-
keras_core_export(["keras_core.Optimizer", "keras_core.optimizers.Optimizer"])
14-
15-
13+
@keras_core_export(["keras_core.Optimizer", "keras_core.optimizers.Optimizer"])
1614
class Optimizer(BackendOptimizer):
1715
pass
1816

keras_core/saving/saving_api.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,63 @@
66
from keras_core.api_export import keras_core_export
77
from keras_core.saving import saving_lib
88
from keras_core.saving.legacy import legacy_h5_format
9+
from keras_core.utils import io_utils
910

1011
try:
1112
import h5py
1213
except ImportError:
1314
h5py = None
1415

1516

17+
@keras_core_export(
18+
["keras_core.saving.save_model", "keras_core.models.save_model"]
19+
)
20+
def save_model(model, filepath, overwrite=True):
21+
"""Saves a model as a `.keras` file.
22+
23+
Args:
24+
model: Keras model instance to be saved.
25+
filepath: `str` or `pathlib.Path` object. Path where to save the model.
26+
overwrite: Whether we should overwrite any existing model at the target
27+
location, or instead ask the user via an interactive prompt.
28+
29+
Example:
30+
31+
```python
32+
model = keras_core.Sequential(
33+
[
34+
keras_core.layers.Dense(5, input_shape=(3,)),
35+
keras_core.layers.Softmax(),
36+
],
37+
)
38+
model.save("model.keras")
39+
loaded_model = keras_core.saving.load_model("model.keras")
40+
x = keras.random.uniform((10, 3))
41+
assert np.allclose(model.predict(x), loaded_model.predict(x))
42+
```
43+
44+
Note that `model.save()` is an alias for `keras_core.saving.save_model()`.
45+
46+
The saved `.keras` file contains:
47+
48+
- The model's configuration (architecture)
49+
- The model's weights
50+
- The model's optimizer's state (if any)
51+
52+
Thus models can be reinstantiated in the exact same state.
53+
"""
54+
# If file exists and should not be overwritten.
55+
try:
56+
exists = os.path.exists(filepath)
57+
except TypeError:
58+
exists = False
59+
if exists and not overwrite:
60+
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
61+
if not proceed:
62+
return
63+
saving_lib.save_model(model, filepath)
64+
65+
1666
@keras_core_export(
1767
["keras_core.saving.load_model", "keras_core.models.load_model"]
1868
)

keras_core/saving/saving_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def save_model(model, filepath, weights_format="h5"):
4242
4343
- JSON-based configuration file (config.json): Records of model, layer, and
4444
other trackables' configuration.
45-
- NPZ-based trackable state files, found in respective directories, such as
45+
- H5-based trackable state files, found in respective directories, such as
4646
model/states.npz, model/dense_layer/states.npz, etc.
4747
- Metadata file.
4848

0 commit comments

Comments
 (0)